auth0_ai_langchain-0.2.0-py3-none-any.whl → auth0_ai_langchain-1.0.0b2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auth0-ai-langchain might be problematic. Click here for more details.
- auth0_ai_langchain/FGARetriever.py +3 -3
- auth0_ai_langchain/auth0_ai.py +101 -31
- auth0_ai_langchain/ciba/__init__.py +3 -0
- auth0_ai_langchain/ciba/ciba_authorizer.py +17 -0
- auth0_ai_langchain/ciba/graph_resumer.py +154 -0
- auth0_ai_langchain/federated_connections/__init__.py +9 -3
- auth0_ai_langchain/federated_connections/federated_connection_authorizer.py +14 -33
- auth0_ai_langchain/fga/__init__.py +4 -0
- auth0_ai_langchain/utils/interrupt.py +18 -1
- auth0_ai_langchain/utils/tool_wrapper.py +34 -0
- auth0_ai_langchain-1.0.0b2.dist-info/METADATA +352 -0
- auth0_ai_langchain-1.0.0b2.dist-info/RECORD +15 -0
- auth0_ai_langchain/ciba/ciba_graph/ciba_graph.py +0 -109
- auth0_ai_langchain/ciba/ciba_graph/initialize_ciba.py +0 -91
- auth0_ai_langchain/ciba/ciba_graph/initialize_hitl.py +0 -50
- auth0_ai_langchain/ciba/ciba_graph/types.py +0 -115
- auth0_ai_langchain/ciba/ciba_graph/utils.py +0 -17
- auth0_ai_langchain/ciba/ciba_poller_graph.py +0 -105
- auth0_ai_langchain/ciba/types.py +0 -8
- auth0_ai_langchain/fga/fga_authorizer.py +0 -3
- auth0_ai_langchain-0.2.0.dist-info/METADATA +0 -221
- auth0_ai_langchain-0.2.0.dist-info/RECORD +0 -19
- {auth0_ai_langchain-0.2.0.dist-info → auth0_ai_langchain-1.0.0b2.dist-info}/LICENSE +0 -0
- {auth0_ai_langchain-0.2.0.dist-info → auth0_ai_langchain-1.0.0b2.dist-info}/WHEEL +0 -0
|
@@ -32,7 +32,7 @@ class FGARetriever(BaseRetriever):
|
|
|
32
32
|
Args:
|
|
33
33
|
retriever (BaseRetriever): The retriever used to fetch documents.
|
|
34
34
|
build_query (Callable[[Document], ClientBatchCheckItem]): Function to convert documents into FGA queries.
|
|
35
|
-
fga_configuration (
|
|
35
|
+
fga_configuration (ClientConfiguration, optional): Configuration for the OpenFGA client. If not provided, defaults to environment variables.
|
|
36
36
|
"""
|
|
37
37
|
super().__init__()
|
|
38
38
|
self._retriever = retriever
|
|
@@ -95,7 +95,7 @@ class FGARetriever(BaseRetriever):
|
|
|
95
95
|
|
|
96
96
|
Args:
|
|
97
97
|
query (str): The query for retrieving documents.
|
|
98
|
-
run_manager (
|
|
98
|
+
run_manager (object, optional): Optional manager for tracking runs.
|
|
99
99
|
|
|
100
100
|
Returns:
|
|
101
101
|
List[Document]: Filtered and relevant documents.
|
|
@@ -148,7 +148,7 @@ class FGARetriever(BaseRetriever):
|
|
|
148
148
|
|
|
149
149
|
Args:
|
|
150
150
|
query (str): The query for retrieving documents.
|
|
151
|
-
run_manager (
|
|
151
|
+
run_manager (object, optional): Optional manager for tracking runs.
|
|
152
152
|
|
|
153
153
|
Returns:
|
|
154
154
|
List[Document]: Filtered and relevant documents.
|
auth0_ai_langchain/auth0_ai.py
CHANGED
|
@@ -1,43 +1,113 @@
|
|
|
1
1
|
from typing import Callable, Optional
|
|
2
|
-
from langchain_core.runnables.config import RunnableConfig
|
|
3
2
|
from langchain_core.tools import BaseTool
|
|
4
|
-
from auth0_ai.
|
|
5
|
-
from auth0_ai.authorizers.
|
|
6
|
-
from auth0_ai.authorizers.
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def get_access_token(config: RunnableConfig) -> Credential:
|
|
12
|
-
"""
|
|
13
|
-
Fetch the access token obtained during the CIBA flow.
|
|
3
|
+
from auth0_ai.authorizers.ciba import CIBAAuthorizerParams
|
|
4
|
+
from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerParams
|
|
5
|
+
from auth0_ai.authorizers.types import Auth0ClientParams
|
|
6
|
+
from auth0_ai_langchain.ciba.ciba_authorizer import CIBAAuthorizer
|
|
7
|
+
from auth0_ai_langchain.federated_connections.federated_connection_authorizer import FederatedConnectionAuthorizer
|
|
8
|
+
|
|
14
9
|
|
|
15
|
-
|
|
16
|
-
|
|
10
|
+
class Auth0AI:
|
|
11
|
+
"""Provides decorators to secure LangChain tools using Auth0 authorization flows.
|
|
17
12
|
"""
|
|
18
|
-
return config.get("configurable", {}).get("_credentials", {}).get("access_token")
|
|
19
13
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
self._graph: Optional[CIBAGraph] = None
|
|
23
|
-
self.config = config
|
|
14
|
+
def __init__(self, auth0: Optional[Auth0ClientParams] = None):
|
|
15
|
+
"""Initializes the Auth0AI instance.
|
|
24
16
|
|
|
25
|
-
|
|
17
|
+
Args:
|
|
18
|
+
auth0 (Optional[Auth0ClientParams]): Parameters for the Auth0 client.
|
|
19
|
+
If not provided, values will be automatically read from environment
|
|
20
|
+
variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`.
|
|
26
21
|
"""
|
|
27
|
-
|
|
22
|
+
self.auth0 = auth0
|
|
28
23
|
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
24
|
+
def with_async_user_confirmation(self, **params: CIBAAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
|
|
25
|
+
"""Protects a tool with the CIBA (Client-Initiated Backchannel Authentication) flow.
|
|
26
|
+
|
|
27
|
+
Requires user confirmation via a second device (e.g., phone)
|
|
28
|
+
before allowing the tool to execute.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
**params: Parameters defined in `CIBAAuthorizerParams`.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.
|
|
35
|
+
|
|
36
|
+
Example:
|
|
37
|
+
```python
|
|
38
|
+
import os
|
|
39
|
+
from auth0_ai_langchain.auth0_ai import Auth0AI
|
|
40
|
+
from auth0_ai_langchain.ciba import get_ciba_credentials
|
|
41
|
+
from langchain_core.runnables import ensure_config
|
|
42
|
+
from langchain_core.tools import StructuredTool
|
|
43
|
+
|
|
44
|
+
auth0_ai = Auth0AI()
|
|
45
|
+
|
|
46
|
+
with_async_user_confirmation = auth0_ai.with_async_user_confirmation(
|
|
47
|
+
scopes=["stock:trade"],
|
|
48
|
+
audience=os.getenv("AUDIENCE"),
|
|
49
|
+
binding_message=lambda ticker, qty: f"Authorize the purchase of {qty} {ticker}",
|
|
50
|
+
user_id=lambda *_, **__: ensure_config().get("configurable", {}).get("user_id")
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
def tool_function(ticker: str, qty: int) -> str:
|
|
54
|
+
credentials = get_ciba_credentials()
|
|
55
|
+
headers = {
|
|
56
|
+
"Authorization": f"{credentials['token_type']} {credentials['access_token']}",
|
|
57
|
+
# ...
|
|
58
|
+
}
|
|
59
|
+
# Call API
|
|
60
|
+
|
|
61
|
+
trade_tool = with_async_user_confirmation(
|
|
62
|
+
StructuredTool(
|
|
63
|
+
name="trade_tool",
|
|
64
|
+
description="Use this function to trade a stock",
|
|
65
|
+
func=tool_function,
|
|
66
|
+
)
|
|
67
|
+
)
|
|
68
|
+
```
|
|
36
69
|
"""
|
|
37
|
-
|
|
70
|
+
authorizer = CIBAAuthorizer(CIBAAuthorizerParams(**params), self.auth0)
|
|
71
|
+
return authorizer.authorizer()
|
|
72
|
+
|
|
73
|
+
def with_federated_connection(self, **params: FederatedConnectionAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
|
|
74
|
+
"""Enables a tool to obtain an access token from a federated identity provider (e.g., Google, Azure AD).
|
|
75
|
+
|
|
76
|
+
The token can then be used within the tool to call third-party APIs on behalf of the user.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
**params: Parameters defined in `FederatedConnectionAuthorizerParams`.
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.
|
|
83
|
+
|
|
84
|
+
Example:
|
|
85
|
+
```python
|
|
86
|
+
from auth0_ai_langchain.auth0_ai import Auth0AI
|
|
87
|
+
from auth0_ai_langchain.federated_connections import get_credentials_for_connection
|
|
88
|
+
from langchain_core.tools import StructuredTool
|
|
89
|
+
from datetime import datetime
|
|
90
|
+
|
|
91
|
+
auth0_ai = Auth0AI()
|
|
92
|
+
|
|
93
|
+
with_google_calendar_access = auth0_ai.with_federated_connection(
|
|
94
|
+
connection="google-oauth2",
|
|
95
|
+
scopes=["https://www.googleapis.com/auth/calendar.freebusy"]
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
def tool_function(date: datetime):
|
|
99
|
+
credentials = get_credentials_for_connection()
|
|
100
|
+
# Call Google API using credentials["access_token"]
|
|
38
101
|
|
|
39
|
-
|
|
40
|
-
|
|
102
|
+
check_calendar_tool = with_google_calendar_access(
|
|
103
|
+
StructuredTool(
|
|
104
|
+
name="check_user_calendar",
|
|
105
|
+
description="Use this function to check if the user is available on a certain date and time",
|
|
106
|
+
func=tool_function,
|
|
107
|
+
)
|
|
108
|
+
)
|
|
109
|
+
```
|
|
41
110
|
"""
|
|
42
|
-
authorizer = FederatedConnectionAuthorizer(
|
|
111
|
+
authorizer = FederatedConnectionAuthorizer(
|
|
112
|
+
FederatedConnectionAuthorizerParams(**params), self.auth0)
|
|
43
113
|
return authorizer.authorizer()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from typing import Union
|
|
3
|
+
from auth0_ai.authorizers.ciba import CIBAAuthorizerBase
|
|
4
|
+
from auth0_ai.interrupts.ciba_interrupts import AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
|
|
5
|
+
from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
|
|
6
|
+
from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
|
|
7
|
+
from langchain_core.tools import BaseTool
|
|
8
|
+
|
|
9
|
+
class CIBAAuthorizer(CIBAAuthorizerBase, ABC):
    """CIBA authorizer adapted to LangChain tools.

    Adds two things on top of ``CIBAAuthorizerBase``: authorization
    interrupts are re-raised as graph interrupts, and ``authorizer()``
    exposes the protection logic as a tool decorator.
    """

    def _handle_authorization_interrupts(
        self,
        err: Union[AuthorizationPendingInterrupt, AuthorizationPollingInterrupt],
    ) -> None:
        """Raise ``err`` converted by ``to_graph_interrupt``.

        Args:
            err: The pending/polling CIBA interrupt to convert.

        Raises:
            The object returned by ``to_graph_interrupt(err)``; this method
            never returns normally.
        """
        graph_interrupt = to_graph_interrupt(err)
        raise graph_interrupt

    def authorizer(self):
        """Build a decorator that wraps a tool with ``self.protect``.

        Returns:
            A callable taking a ``BaseTool`` and returning the result of
            ``tool_wrapper(tool, self.protect)``.
        """

        def decorate(tool: BaseTool) -> BaseTool:
            # ``self.protect`` is looked up when the decorator is applied,
            # not when ``authorizer()`` is called.
            protected = tool_wrapper(tool, self.protect)
            return protected

        return decorate
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from threading import Event
|
|
3
|
+
from typing import Callable, Optional, Dict, Any, List, TypedDict
|
|
4
|
+
from auth0_ai.authorizers.ciba import CIBAAuthorizationRequest
|
|
5
|
+
from auth0_ai.interrupts.ciba_interrupts import CIBAInterrupt, AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
|
|
6
|
+
from auth0_ai_langchain.utils.interrupt import get_auth0_interrupts
|
|
7
|
+
from langgraph_sdk.client import LangGraphClient
|
|
8
|
+
from langgraph_sdk.schema import Thread, Interrupt
|
|
9
|
+
|
|
10
|
+
class _WatchedThreadRequired(TypedDict):
    """Keys that every watched-thread entry carries from the moment it is created."""

    # Identifier of the interrupted thread.
    thread_id: str
    # Assistant to run when resuming (GraphResumer fills it from the thread
    # metadata's "graph_id").
    assistant_id: str
    # First key of the thread's "interrupts" mapping.
    interruption_id: str
    # The "_request" payload taken from the interrupt value; its "interval"
    # entry is read as a number of seconds between two resume attempts.
    auth_request: "CIBAAuthorizationRequest"
    # Config forwarded to the run that resumes the thread.
    config: Dict[str, Any]


class WatchedThread(_WatchedThreadRequired, total=False):
    """Bookkeeping entry for one interrupted thread tracked by ``GraphResumer``.

    ``last_run`` is optional: entries are created without it and it is only
    written after the thread has been resumed once, which is why consumers
    test ``"last_run" not in entry``.
    """

    # Event-loop clock reading, in milliseconds, of the last resume attempt.
    last_run: float
|
|
17
|
+
|
|
18
|
+
class GraphResumerFilters(TypedDict, total=False):
    """Optional criteria restricting which interrupted threads are watched.

    Every key is optional (``total=False``): ``GraphResumer`` accepts no
    filters at all and checks ``"graph_id" in filters`` before using it.
    """

    # When present, it is sent as the ``graph_id`` metadata filter of the
    # thread search.
    graph_id: str
|
|
20
|
+
|
|
21
|
+
class GraphResumer:
    """Watches interrupted LangGraph threads and periodically re-runs them.

    Each polling pass (``loop``) lists the threads whose status is
    ``"interrupted"``, keeps the ones paused on a CIBA interrupt that carries
    request data, and starts a new run for every watched thread that has not
    been resumed yet or whose ``auth_request["interval"]`` (seconds) has
    elapsed since its last resume.

    ``start()`` schedules the passes as a background asyncio task; ``stop()``
    cancels it. Listeners registered with ``on_resume`` / ``on_error`` are
    called synchronously from that task.
    """

    # Page size used when listing interrupted threads.
    _PAGE_SIZE = 100
    # Seconds slept between two polling passes of the background task.
    _POLL_DELAY_SECONDS = 5

    def __init__(self, lang_graph: "LangGraphClient", filters: "Optional[GraphResumerFilters]" = None):
        """Create a resumer.

        Args:
            lang_graph: Client used to search threads and start runs.
            filters: Optional filters; only ``graph_id`` is consulted.
        """
        self.lang_graph = lang_graph
        self.filters = filters or {}
        # Watched threads keyed by "<thread_id>:<interruption_id>".
        self.map: Dict[str, WatchedThread] = {}
        self._stop_event = Event()
        self._loop_task: Optional[asyncio.Task] = None

        # Event callbacks
        self._resume_callbacks: List[Callable[[WatchedThread], None]] = []
        self._error_callbacks: List[Callable[[Exception], None]] = []

    # Public API to register event callbacks
    def on_resume(self, callback: "Callable[[WatchedThread], None]") -> "GraphResumer":
        """Register ``callback`` to be called just before a thread is resumed.

        Returns:
            ``self``, so registrations can be chained.
        """
        self._resume_callbacks.append(callback)
        return self

    def on_error(self, callback: "Callable[[Exception], None]") -> "GraphResumer":
        """Register ``callback`` to receive exceptions raised by a polling pass.

        Returns:
            ``self``, so registrations can be chained.
        """
        self._error_callbacks.append(callback)
        return self

    def _emit_resume(self, thread: "WatchedThread") -> None:
        """Call every resume listener with ``thread``, in registration order."""
        for callback in self._resume_callbacks:
            callback(thread)

    def _emit_error(self, error: Exception) -> None:
        """Call every error listener with ``error``, in registration order."""
        for callback in self._error_callbacks:
            callback(error)

    async def _get_all_interrupted_threads(self) -> "List[Thread]":
        """Page through all interrupted threads and keep the CIBA ones.

        A thread is kept when its first interrupt is recognised by
        ``CIBAInterrupt.is_interrupt`` and ``CIBAInterrupt.has_request_data``.

        Returns:
            The matching threads, in the order the search returned them.
        """
        interrupted_threads: List[Thread] = []
        offset = 0
        metadata = {"graph_id": self.filters["graph_id"]} if "graph_id" in self.filters else None

        while True:
            page = await self.lang_graph.threads.search(
                status="interrupted",
                limit=self._PAGE_SIZE,
                offset=offset,
                metadata=metadata,
            )

            if not page:
                break

            for t in page:
                interrupt = self._get_first_interrupt(t)
                if (
                    interrupt
                    and CIBAInterrupt.is_interrupt(interrupt["value"])
                    and CIBAInterrupt.has_request_data(interrupt["value"])
                ):
                    interrupted_threads.append(t)

            offset += len(page)
            # A short page means there is nothing left to fetch.
            if len(page) < self._PAGE_SIZE:
                break

        return interrupted_threads

    def _get_first_interrupt(self, thread: "Thread") -> "Optional[Interrupt]":
        """Return the first interrupt of the first ``thread["interrupts"]`` entry.

        Returns:
            The interrupt, or ``None`` when the mapping is empty or its first
            entry holds no interrupts.
        """
        interrupts = thread["interrupts"]
        if interrupts:
            first_entry = next(iter(interrupts.values()))
            if first_entry:
                return first_entry[0]
        return None

    def _get_hash_map_id(self, thread: "Thread") -> str:
        """Build the ``self.map`` key: ``"<thread_id>:<first interrupt id>"``."""
        return f"{thread['thread_id']}:{next(iter(thread['interrupts']))}"

    @staticmethod
    def _get_thread_config(thread: "Thread") -> Dict[str, Any]:
        """Return the config stored on ``thread``, or ``{}`` when there is none.

        Threads are plain mappings, so the value is read with a key lookup;
        ``getattr`` on a dict never finds it. Attribute access is kept as a
        fallback for non-mapping thread objects.
        """
        if isinstance(thread, dict):
            config = thread.get("config")
        else:
            config = getattr(thread, "config", None)
        return config or {}

    async def _resume_thread(self, t: "WatchedThread"):
        """Notify resume listeners, run the thread again and stamp ``last_run``.

        ``last_run`` is only written when the run completes without raising.
        """
        self._emit_resume(t)

        await self.lang_graph.runs.wait(t["thread_id"], t["assistant_id"], config=t["config"])

        # Event-loop clock (seconds) converted to milliseconds.
        t["last_run"] = asyncio.get_running_loop().time() * 1000

    async def loop(self):
        """Run one polling pass: sync ``self.map`` with the server, then resume due threads.

        Exceptions raised while searching or resuming propagate to the caller.
        """
        all_threads = await self._get_all_interrupted_threads()

        # Remove old interrupted threads
        active_keys = {self._get_hash_map_id(t) for t in all_threads}
        for key in list(self.map):
            if key not in active_keys:
                del self.map[key]

        # Add new interrupted threads
        for thread in all_threads:
            interrupt = next(
                (
                    i for i in get_auth0_interrupts(thread)
                    if AuthorizationPendingInterrupt.is_interrupt(i["value"])
                    or AuthorizationPollingInterrupt.is_interrupt(i["value"])
                ),
                None,
            )

            if not interrupt or not interrupt["value"].get("_request"):
                continue

            key = self._get_hash_map_id(thread)
            if key not in self.map:
                self.map[key] = {
                    "thread_id": thread["thread_id"],
                    "assistant_id": thread["metadata"].get("graph_id"),
                    "config": self._get_thread_config(thread),
                    "interruption_id": next(iter(thread["interrupts"])),
                    "auth_request": interrupt["value"]["_request"],
                }

        # Read the clock once so every entry is compared against the same instant.
        now_ms = asyncio.get_running_loop().time() * 1000
        threads_to_resume = [
            t for t in self.map.values()
            if "last_run" not in t
            or t["last_run"] + t["auth_request"]["interval"] * 1000 < now_ms
        ]

        await asyncio.gather(*(self._resume_thread(t) for t in threads_to_resume))

    def start(self):
        """Schedule the background polling task; no-op if one is still running.

        Uses ``asyncio.create_task``, so it must be called while an event
        loop is running.
        """
        if self._loop_task and not self._loop_task.done():
            return

        self._stop_event.clear()

        async def _run_loop():
            while not self._stop_event.is_set():
                try:
                    await self.loop()
                except Exception as e:
                    # Keep polling after a failed pass; report it to listeners.
                    self._emit_error(e)
                await asyncio.sleep(self._POLL_DELAY_SECONDS)

        self._loop_task = asyncio.create_task(_run_loop())

    def stop(self):
        """Signal the polling task to stop and cancel it if it was started."""
        self._stop_event.set()
        if self._loop_task:
            self._loop_task.cancel()
|
|
@@ -1,4 +1,10 @@
|
|
|
1
|
-
from auth0_ai.interrupts.federated_connection_interrupt import
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
from auth0_ai.interrupts.federated_connection_interrupt import (
|
|
2
|
+
FederatedConnectionError as FederatedConnectionError,
|
|
3
|
+
FederatedConnectionInterrupt as FederatedConnectionInterrupt
|
|
4
|
+
)
|
|
5
|
+
|
|
6
|
+
from auth0_ai.authorizers.federated_connection_authorizer import (
|
|
7
|
+
get_credentials_for_connection as get_credentials_for_connection,
|
|
8
|
+
get_access_token_for_connection as get_access_token_for_connection
|
|
9
|
+
)
|
|
4
10
|
from .federated_connection_authorizer import FederatedConnectionAuthorizer as FederatedConnectionAuthorizer
|
|
@@ -1,52 +1,33 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
from abc import ABC
|
|
3
3
|
from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerBase, FederatedConnectionAuthorizerParams
|
|
4
|
-
from auth0_ai.authorizers.types import
|
|
4
|
+
from auth0_ai.authorizers.types import Auth0ClientParams
|
|
5
5
|
from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionInterrupt
|
|
6
|
-
from
|
|
6
|
+
from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
|
|
7
|
+
from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
|
|
8
|
+
from langchain_core.tools import BaseTool
|
|
7
9
|
from langchain_core.runnables import ensure_config
|
|
8
|
-
from ..utils.interrupt import to_graph_interrupt
|
|
9
10
|
|
|
10
|
-
async def
|
|
11
|
+
async def default_get_refresh_token(*_, **__) -> str | None:
|
|
11
12
|
return ensure_config().get("configurable", {}).get("_credentials", {}).get("refresh_token")
|
|
12
13
|
|
|
13
14
|
class FederatedConnectionAuthorizer(FederatedConnectionAuthorizerBase, ABC):
|
|
14
15
|
def __init__(
|
|
15
16
|
self,
|
|
16
|
-
|
|
17
|
-
|
|
17
|
+
params: FederatedConnectionAuthorizerParams,
|
|
18
|
+
auth0: Auth0ClientParams = None,
|
|
18
19
|
):
|
|
19
|
-
if
|
|
20
|
-
|
|
21
|
-
|
|
20
|
+
if params.refresh_token.value is None:
|
|
21
|
+
params = copy.copy(params)
|
|
22
|
+
params.refresh_token.value = default_get_refresh_token
|
|
22
23
|
|
|
23
|
-
super().__init__(
|
|
24
|
+
super().__init__(params, auth0)
|
|
24
25
|
|
|
25
26
|
def _handle_authorization_interrupts(self, err: FederatedConnectionInterrupt) -> None:
|
|
26
27
|
raise to_graph_interrupt(err)
|
|
27
28
|
|
|
28
29
|
def authorizer(self):
|
|
29
|
-
def
|
|
30
|
-
|
|
31
|
-
return await t.ainvoke(input=kwargs)
|
|
32
|
-
|
|
33
|
-
tool_fn = self.protect(
|
|
34
|
-
lambda *_args, **_kwargs: {
|
|
35
|
-
"thread_id": ensure_config().get("configurable", {}).get("thread_id"),
|
|
36
|
-
"checkpoint_ns": ensure_config().get("configurable", {}).get("checkpoint_ns"),
|
|
37
|
-
"run_id": ensure_config().get("configurable", {}).get("run_id"),
|
|
38
|
-
"tool_call_id": ensure_config().get("configurable", {}).get("tool_call_id"), # TODO: review this
|
|
39
|
-
},
|
|
40
|
-
execute_fn
|
|
41
|
-
)
|
|
42
|
-
tool_fn.__name__ = t.name
|
|
43
|
-
|
|
44
|
-
return tool(
|
|
45
|
-
tool_fn,
|
|
46
|
-
description=t.description,
|
|
47
|
-
return_direct=t.return_direct,
|
|
48
|
-
args_schema=t.args_schema,
|
|
49
|
-
response_format=t.response_format,
|
|
50
|
-
)
|
|
30
|
+
def wrap_tool(tool: BaseTool) -> BaseTool:
|
|
31
|
+
return tool_wrapper(tool, self.protect)
|
|
51
32
|
|
|
52
|
-
return
|
|
33
|
+
return wrap_tool
|
|
@@ -1,6 +1,9 @@
|
|
|
1
|
+
from typing import List
|
|
1
2
|
from auth0_ai.interrupts.auth0_interrupt import Auth0Interrupt
|
|
2
3
|
from langgraph.errors import GraphInterrupt
|
|
3
4
|
from langgraph.types import Interrupt
|
|
5
|
+
from langgraph_sdk.schema import Thread
|
|
6
|
+
|
|
4
7
|
|
|
5
8
|
def to_graph_interrupt(interrupt: Auth0Interrupt) -> GraphInterrupt:
|
|
6
9
|
return GraphInterrupt([
|
|
@@ -8,6 +11,20 @@ def to_graph_interrupt(interrupt: Auth0Interrupt) -> GraphInterrupt:
|
|
|
8
11
|
value=interrupt.to_json(),
|
|
9
12
|
when="during",
|
|
10
13
|
resumable=True,
|
|
11
|
-
ns=[f"auth0AI:{interrupt.
|
|
14
|
+
ns=[f"auth0AI:{interrupt.name}:{interrupt.code}"]
|
|
12
15
|
)
|
|
13
16
|
])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_auth0_interrupts(thread: "Thread") -> "List[Interrupt]":
    """Collect the Auth0 interrupts recorded on a thread.

    Args:
        thread: Thread mapping; its optional ``"interrupts"`` entry maps an
            interrupt id to a list of interrupts.

    Returns:
        Every interrupt whose ``"value"`` is recognised by
        ``Auth0Interrupt.is_interrupt``, in iteration order. The list is
        empty when the thread has no ``"interrupts"`` entry or when that
        entry is ``None`` or empty.
    """
    # ``or {}`` also covers a present-but-None entry, which would otherwise
    # fail on ``.values()``.
    interrupts_by_id = thread.get("interrupts") or {}

    return [
        interrupt
        for interrupt_list in interrupts_by_id.values()
        for interrupt in interrupt_list
        if Auth0Interrupt.is_interrupt(interrupt["value"])
    ]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
from typing_extensions import Annotated
|
|
3
|
+
from pydantic import create_model
|
|
4
|
+
from langchain_core.tools import BaseTool, tool as create_tool, InjectedToolCallId
|
|
5
|
+
from langchain_core.runnables import RunnableConfig
|
|
6
|
+
|
|
7
|
+
def tool_wrapper(tool: BaseTool, protect_fn: Callable) -> BaseTool:
    """Wrap ``tool`` so that every call goes through ``protect_fn``.

    The returned tool keeps the original name, description and
    ``return_direct`` flag. Its args schema is the original one extended
    with an injected ``tool_call_id`` field.

    Args:
        tool: The tool to protect. Its ``args_schema`` must be a model class,
            because it is used as the base of the extended schema.
        protect_fn: Called as ``protect_fn(get_context, execute_fn)``. It must
            return an async callable, which is then awaited with the tool
            input as keyword arguments. ``get_context`` returns a dict with
            ``thread_id``, ``tool_call_id`` and ``tool_name``; ``execute_fn``
            invokes the original tool.

    Returns:
        The wrapping tool.
    """
    # Workaround: extend existing args_schema to be able to get the tool_call_id value
    args_schema = create_model(
        tool.args_schema.__name__ + "Extended",
        __base__=tool.args_schema,
        # Canonical ``(type, default)`` form; ``...`` marks the field as required.
        tool_call_id=(Annotated[str, InjectedToolCallId], ...),
    )

    @create_tool(
        tool.name,
        description=tool.description,
        args_schema=args_schema,
        # Preserve the original tool's flag instead of resetting it to the default.
        return_direct=tool.return_direct,
    )
    async def wrapped_tool(config: RunnableConfig, tool_call_id: Annotated[str, InjectedToolCallId], **tool_input):
        async def execute_fn(*_, **__):
            # Ignore whatever protect_fn passes and replay the captured input.
            return await tool.ainvoke(tool_input, config)

        def get_context(*_, **__):
            return {
                "thread_id": config.get("configurable", {}).get("thread_id"),
                "tool_call_id": tool_call_id,
                "tool_name": tool.name,
            }

        return await protect_fn(get_context, execute_fn)(**tool_input)

    return wrapped_tool
|