auth0-ai-langchain 0.1.2__py3-none-any.whl → 1.0.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auth0-ai-langchain might be problematic. Click here for more details.
- auth0_ai_langchain/FGARetriever.py +3 -3
- auth0_ai_langchain/auth0_ai.py +100 -31
- auth0_ai_langchain/ciba/__init__.py +3 -0
- auth0_ai_langchain/ciba/ciba_authorizer.py +17 -0
- auth0_ai_langchain/ciba/graph_resumer.py +154 -0
- auth0_ai_langchain/federated_connections/__init__.py +6 -2
- auth0_ai_langchain/federated_connections/federated_connection_authorizer.py +14 -33
- auth0_ai_langchain/fga/__init__.py +4 -0
- auth0_ai_langchain/utils/interrupt.py +23 -6
- auth0_ai_langchain/utils/tool_wrapper.py +34 -0
- {auth0_ai_langchain-0.1.2.dist-info → auth0_ai_langchain-1.0.0b1.dist-info}/METADATA +127 -27
- auth0_ai_langchain-1.0.0b1.dist-info/RECORD +15 -0
- auth0_ai_langchain/ciba/ciba_graph/ciba_graph.py +0 -109
- auth0_ai_langchain/ciba/ciba_graph/initialize_ciba.py +0 -91
- auth0_ai_langchain/ciba/ciba_graph/initialize_hitl.py +0 -50
- auth0_ai_langchain/ciba/ciba_graph/types.py +0 -115
- auth0_ai_langchain/ciba/ciba_graph/utils.py +0 -17
- auth0_ai_langchain/ciba/ciba_poller_graph.py +0 -105
- auth0_ai_langchain/ciba/types.py +0 -8
- auth0_ai_langchain/fga/fga_authorizer.py +0 -3
- auth0_ai_langchain-0.1.2.dist-info/RECORD +0 -19
- {auth0_ai_langchain-0.1.2.dist-info → auth0_ai_langchain-1.0.0b1.dist-info}/LICENSE +0 -0
- {auth0_ai_langchain-0.1.2.dist-info → auth0_ai_langchain-1.0.0b1.dist-info}/WHEEL +0 -0
|
@@ -32,7 +32,7 @@ class FGARetriever(BaseRetriever):
|
|
|
32
32
|
Args:
|
|
33
33
|
retriever (BaseRetriever): The retriever used to fetch documents.
|
|
34
34
|
build_query (Callable[[Document], ClientBatchCheckItem]): Function to convert documents into FGA queries.
|
|
35
|
-
fga_configuration (
|
|
35
|
+
fga_configuration (ClientConfiguration, optional): Configuration for the OpenFGA client. If not provided, defaults to environment variables.
|
|
36
36
|
"""
|
|
37
37
|
super().__init__()
|
|
38
38
|
self._retriever = retriever
|
|
@@ -95,7 +95,7 @@ class FGARetriever(BaseRetriever):
|
|
|
95
95
|
|
|
96
96
|
Args:
|
|
97
97
|
query (str): The query for retrieving documents.
|
|
98
|
-
run_manager (
|
|
98
|
+
run_manager (object, optional): Optional manager for tracking runs.
|
|
99
99
|
|
|
100
100
|
Returns:
|
|
101
101
|
List[Document]: Filtered and relevant documents.
|
|
@@ -148,7 +148,7 @@ class FGARetriever(BaseRetriever):
|
|
|
148
148
|
|
|
149
149
|
Args:
|
|
150
150
|
query (str): The query for retrieving documents.
|
|
151
|
-
run_manager (
|
|
151
|
+
run_manager (object, optional): Optional manager for tracking runs.
|
|
152
152
|
|
|
153
153
|
Returns:
|
|
154
154
|
List[Document]: Filtered and relevant documents.
|
auth0_ai_langchain/auth0_ai.py
CHANGED
|
@@ -1,43 +1,112 @@
|
|
|
1
1
|
from typing import Callable, Optional
|
|
2
|
-
from langchain_core.runnables.config import RunnableConfig
|
|
3
2
|
from langchain_core.tools import BaseTool
|
|
4
|
-
from auth0_ai.
|
|
5
|
-
from auth0_ai.authorizers.
|
|
6
|
-
from auth0_ai.authorizers.
|
|
7
|
-
from .
|
|
8
|
-
from .
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def get_access_token(config: RunnableConfig) -> Credential:
|
|
12
|
-
"""
|
|
13
|
-
Fetch the access token obtained during the CIBA flow.
|
|
3
|
+
from auth0_ai.authorizers.ciba import CIBAAuthorizerParams
|
|
4
|
+
from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerParams
|
|
5
|
+
from auth0_ai.authorizers.types import Auth0ClientParams
|
|
6
|
+
from auth0_ai_langchain.ciba.ciba_authorizer import CIBAAuthorizer
|
|
7
|
+
from auth0_ai_langchain.federated_connections.federated_connection_authorizer import FederatedConnectionAuthorizer
|
|
8
|
+
|
|
14
9
|
|
|
15
|
-
|
|
16
|
-
|
|
10
|
+
class Auth0AI:
    """Provides decorators to secure LangChain tools using Auth0 authorization flows.
    """

    def __init__(self, auth0: Optional[Auth0ClientParams] = None):
        """Initializes the Auth0AI instance.

        Args:
            auth0 (Optional[Auth0ClientParams]): Parameters for the Auth0 client.
                If not provided, values will be automatically read from environment
                variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`.
        """
        # Forwarded unchanged to every authorizer created by this instance.
        self.auth0 = auth0

    # NOTE: ``**params`` is intentionally left unannotated. Per PEP 484,
    # ``**params: T`` would declare that *each keyword value* is a ``T``,
    # which is not the case: the keywords are the individual fields of the
    # params object (``scope="..."``, ``audience="..."``, ...).
    def with_async_user_confirmation(self, **params) -> Callable[[BaseTool], BaseTool]:
        """Protects a tool with the CIBA (Client-Initiated Backchannel Authentication) flow.

        Requires user confirmation via a second device (e.g., phone)
        before allowing the tool to execute.

        Args:
            **params: Keyword arguments accepted by `CIBAAuthorizerParams`.

        Returns:
            Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.

        Example:
            ```python
            import os
            from auth0_ai_langchain.auth0_ai import Auth0AI
            from auth0_ai_langchain.ciba import get_ciba_credentials
            from langchain_core.runnables import ensure_config
            from langchain_core.tools import StructuredTool

            auth0_ai = Auth0AI()

            with_async_user_confirmation = auth0_ai.with_async_user_confirmation(
                scope="stock:trade",
                audience=os.getenv("AUDIENCE"),
                binding_message=lambda ticker, qty: f"Authorize the purchase of {qty} {ticker}",
                user_id=lambda *_, **__: ensure_config().get("configurable", {}).get("user_id")
            )

            def tool_function(ticker: str, qty: int) -> str:
                credentials = get_ciba_credentials()
                headers = {
                    "Authorization": f"{credentials['token_type']} {credentials['access_token']}",
                    # ...
                }
                # Call API

            trade_tool = with_async_user_confirmation(
                StructuredTool(
                    name="trade_tool",
                    description="Use this function to trade a stock",
                    func=tool_function,
                )
            )
            ```
        """
        authorizer = CIBAAuthorizer(CIBAAuthorizerParams(**params), self.auth0)
        return authorizer.authorizer()

    def with_federated_connection(self, **params) -> Callable[[BaseTool], BaseTool]:
        """Enables a tool to obtain an access token from a federated identity provider (e.g., Google, Azure AD).

        The token can then be used within the tool to call third-party APIs on behalf of the user.

        Args:
            **params: Keyword arguments accepted by `FederatedConnectionAuthorizerParams`.

        Returns:
            Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.

        Example:
            ```python
            from auth0_ai_langchain.auth0_ai import Auth0AI
            from auth0_ai_langchain.federated_connections import get_credentials_for_connection
            from langchain_core.tools import StructuredTool
            from datetime import datetime

            auth0_ai = Auth0AI()

            with_google_calendar_access = auth0_ai.with_federated_connection(
                connection="google-oauth2",
                scopes=["https://www.googleapis.com/auth/calendar.freebusy"]
            )

            def tool_function(date: datetime):
                credentials = get_credentials_for_connection()
                # Call Google API using credentials["access_token"]

            check_calendar_tool = with_google_calendar_access(
                StructuredTool(
                    name="check_user_calendar",
                    description="Use this function to check if the user is available on a certain date and time",
                    func=tool_function,
                )
            )
            ```
        """
        authorizer = FederatedConnectionAuthorizer(FederatedConnectionAuthorizerParams(**params), self.auth0)
        return authorizer.authorizer()
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from typing import Union
|
|
3
|
+
from auth0_ai.authorizers.ciba import CIBAAuthorizerBase
|
|
4
|
+
from auth0_ai.interrupts.ciba_interrupts import AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
|
|
5
|
+
from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
|
|
6
|
+
from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
|
|
7
|
+
from langchain_core.tools import BaseTool
|
|
8
|
+
|
|
9
|
+
class CIBAAuthorizer(CIBAAuthorizerBase, ABC):
    """LangChain adapter around the framework-agnostic CIBA authorizer.

    Auth0 CIBA interrupts are re-raised as LangGraph ``GraphInterrupt``s, and
    :meth:`authorizer` returns a decorator that routes a LangChain tool
    through ``self.protect``.
    """

    def _handle_authorization_interrupts(
        self,
        err: Union[AuthorizationPendingInterrupt, AuthorizationPollingInterrupt],
    ) -> None:
        # Convert the Auth0 interrupt into the exception type LangGraph understands.
        graph_interrupt = to_graph_interrupt(err)
        raise graph_interrupt

    def authorizer(self):
        """Return a decorator that wraps a LangChain tool with CIBA protection."""

        def _protect_tool(tool: BaseTool) -> BaseTool:
            # ``self.protect`` is looked up at wrap time, not at decorator creation.
            return tool_wrapper(tool, self.protect)

        return _protect_tool
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from threading import Event
|
|
3
|
+
from typing import Callable, Optional, Dict, Any, List, TypedDict
|
|
4
|
+
from auth0_ai.authorizers.ciba import CIBAAuthorizationRequest
|
|
5
|
+
from auth0_ai.interrupts.ciba_interrupts import CIBAInterrupt, AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
|
|
6
|
+
from auth0_ai_langchain.utils.interrupt import get_auth0_interrupts
|
|
7
|
+
from langgraph_sdk.client import LangGraphClient
|
|
8
|
+
from langgraph_sdk.schema import Thread, Interrupt
|
|
9
|
+
|
|
10
|
+
class _WatchedThreadRequired(TypedDict):
    """Keys that are always present on a thread tracked by ``GraphResumer``."""

    thread_id: str
    # Graph id read from the thread's metadata; used as the assistant when resuming.
    assistant_id: str
    # Key of the thread's first interrupt; identifies this particular interruption.
    interruption_id: str
    # CIBA request carried by the interrupt; its ``interval`` (seconds) throttles resumes.
    auth_request: CIBAAuthorizationRequest
    # Config passed back to LangGraph when the run is resumed.
    config: Dict[str, Any]


class WatchedThread(_WatchedThreadRequired, total=False):
    """An interrupted LangGraph thread tracked by ``GraphResumer``.

    ``last_run`` is absent until the first resume attempt completes (the
    resumer checks ``"last_run" not in t``), so it is declared as an optional
    key rather than a required one.
    """

    # Event-loop time, in milliseconds, at which the last resume completed.
    last_run: float
|
|
17
|
+
|
|
18
|
+
class GraphResumerFilters(TypedDict, total=False):
    """Optional filters for ``GraphResumer``.

    Every key is optional: the resumer defaults to ``{}`` and only applies
    ``graph_id`` when it is present.
    """

    # Only watch threads whose metadata ``graph_id`` equals this value.
    graph_id: str
|
|
20
|
+
|
|
21
|
+
class GraphResumer:
    """Watch a LangGraph server for CIBA-interrupted threads and resume them.

    On every polling iteration the resumer lists threads in the
    ``"interrupted"`` status (optionally restricted by ``graph_id`` metadata),
    tracks those carrying an Auth0 authorization-pending / authorization-polling
    CIBA interrupt with request data, and re-runs each tracked thread once its
    CIBA ``interval`` (seconds) has elapsed since the previous attempt.

    Register callbacks with :meth:`on_resume` / :meth:`on_error`, then call
    :meth:`start` from inside a running asyncio event loop.
    """

    # Number of threads requested per ``threads.search`` page.
    _PAGE_SIZE = 100
    # Seconds to sleep between polling iterations.
    _POLL_INTERVAL_SECONDS = 5

    def __init__(self, lang_graph: LangGraphClient, filters: Optional[GraphResumerFilters] = None):
        """Create a resumer.

        Args:
            lang_graph: LangGraph SDK client used to search and resume threads.
            filters: Optional filters; ``graph_id`` restricts polling to threads
                whose metadata contains that ``graph_id``.
        """
        self.lang_graph = lang_graph
        self.filters = filters or {}
        # Tracked threads keyed by "<thread_id>:<first interrupt key>".
        self.map: Dict[str, WatchedThread] = {}
        self._stop_event = Event()
        self._loop_task: Optional[asyncio.Task] = None

        # Event callbacks
        self._resume_callbacks: List[Callable[[WatchedThread], None]] = []
        self._error_callbacks: List[Callable[[Exception], None]] = []

    # Public API to register event callbacks
    def on_resume(self, callback: Callable[[WatchedThread], None]) -> "GraphResumer":
        """Register ``callback`` to be called right before a thread is resumed."""
        self._resume_callbacks.append(callback)
        return self

    def on_error(self, callback: Callable[[Exception], None]) -> "GraphResumer":
        """Register ``callback`` to be called with errors raised while polling or resuming."""
        self._error_callbacks.append(callback)
        return self

    def _emit_resume(self, thread: WatchedThread) -> None:
        for callback in self._resume_callbacks:
            callback(thread)

    def _emit_error(self, error: Exception) -> None:
        for callback in self._error_callbacks:
            callback(error)

    @staticmethod
    def _now_ms() -> float:
        """Current event-loop time in milliseconds (must be called from a coroutine)."""
        return asyncio.get_running_loop().time() * 1000

    async def _get_all_interrupted_threads(self) -> List[Thread]:
        """Page through interrupted threads and keep those with a CIBA request interrupt."""
        interrupted_threads: List[Thread] = []
        offset = 0
        metadata = {"graph_id": self.filters["graph_id"]} if "graph_id" in self.filters else None

        while True:
            page = await self.lang_graph.threads.search(
                status="interrupted",
                limit=self._PAGE_SIZE,
                offset=offset,
                metadata=metadata,
            )

            if not page:
                break

            for t in page:
                interrupt = self._get_first_interrupt(t)
                if (
                    interrupt
                    and CIBAInterrupt.is_interrupt(interrupt["value"])
                    and CIBAInterrupt.has_request_data(interrupt["value"])
                ):
                    interrupted_threads.append(t)

            offset += len(page)
            # A short page means there is nothing left to fetch.
            if len(page) < self._PAGE_SIZE:
                break

        return interrupted_threads

    def _get_first_interrupt(self, thread: Thread) -> Optional[Interrupt]:
        """Return the first interrupt of the first interrupt group, or ``None``."""
        interrupts = thread.get("interrupts")
        if interrupts:
            first_group = next(iter(interrupts.values()))
            if first_group:
                return first_group[0]
        return None

    def _get_hash_map_id(self, thread: Thread) -> str:
        return f"{thread['thread_id']}:{next(iter(thread['interrupts']))}"

    async def _resume_thread(self, t: WatchedThread):
        self._emit_resume(t)

        await self.lang_graph.runs.wait(t["thread_id"], t["assistant_id"], config=t["config"])

        # Only recorded on success, so a failed resume is retried on the next iteration.
        t["last_run"] = self._now_ms()

    async def loop(self):
        """Run a single polling iteration: sync tracked threads, then resume the due ones."""
        all_threads = await self._get_all_interrupted_threads()

        # Remove old interrupted threads
        active_keys = {self._get_hash_map_id(t) for t in all_threads}

        for key in list(self.map):
            if key not in active_keys:
                del self.map[key]

        # Add new interrupted threads
        for thread in all_threads:
            interrupt = next(
                (
                    i for i in get_auth0_interrupts(thread)
                    if AuthorizationPendingInterrupt.is_interrupt(i["value"])
                    or AuthorizationPollingInterrupt.is_interrupt(i["value"])
                ),
                None,
            )

            if not interrupt or not interrupt["value"].get("_request"):
                continue

            key = self._get_hash_map_id(thread)
            if key not in self.map:
                self.map[key] = {
                    "thread_id": thread["thread_id"],
                    "assistant_id": (thread.get("metadata") or {}).get("graph_id"),
                    # ``Thread`` is a dict, so "config" must be read as a key;
                    # ``getattr(thread, "config", {})`` always returned ``{}``.
                    "config": thread.get("config") or {},
                    "interruption_id": next(iter(thread["interrupts"])),
                    "auth_request": interrupt["value"]["_request"],
                }

        now = self._now_ms()
        threads_to_resume = [
            t for t in self.map.values()
            # ``interval`` is in seconds; timestamps are in milliseconds.
            if "last_run" not in t or (t["last_run"] + t["auth_request"]["interval"] * 1000 < now)
        ]

        # Collect every failure instead of aborting on the first one, so one bad
        # thread cannot leave the other resumes unawaited and unreported.
        results = await asyncio.gather(
            *(self._resume_thread(t) for t in threads_to_resume),
            return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                self._emit_error(result)

    def start(self):
        """Start polling in a background task (requires a running event loop)."""
        if self._loop_task and not self._loop_task.done():
            return

        self._stop_event.clear()

        async def _run_loop():
            while not self._stop_event.is_set():
                try:
                    await self.loop()
                except Exception as e:
                    self._emit_error(e)
                await asyncio.sleep(self._POLL_INTERVAL_SECONDS)

        self._loop_task = asyncio.create_task(_run_loop())

    def stop(self):
        """Signal the polling loop to stop and cancel its task."""
        self._stop_event.set()
        if self._loop_task:
            self._loop_task.cancel()
|
|
@@ -1,3 +1,7 @@
|
|
|
1
|
-
from auth0_ai.interrupts.federated_connection_interrupt import
|
|
2
|
-
|
|
1
|
+
from auth0_ai.interrupts.federated_connection_interrupt import (
|
|
2
|
+
FederatedConnectionError as FederatedConnectionError,
|
|
3
|
+
FederatedConnectionInterrupt as FederatedConnectionInterrupt
|
|
4
|
+
)
|
|
5
|
+
|
|
6
|
+
from auth0_ai.authorizers.federated_connection_authorizer import get_credentials_for_connection as get_credentials_for_connection
|
|
3
7
|
from .federated_connection_authorizer import FederatedConnectionAuthorizer as FederatedConnectionAuthorizer
|
|
@@ -1,52 +1,33 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
from abc import ABC
|
|
3
3
|
from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerBase, FederatedConnectionAuthorizerParams
|
|
4
|
-
from auth0_ai.authorizers.types import
|
|
4
|
+
from auth0_ai.authorizers.types import Auth0ClientParams
|
|
5
5
|
from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionInterrupt
|
|
6
|
-
from
|
|
6
|
+
from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
|
|
7
|
+
from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
|
|
8
|
+
from langchain_core.tools import BaseTool
|
|
7
9
|
from langchain_core.runnables import ensure_config
|
|
8
|
-
from ..utils.interrupt import to_graph_interrupt
|
|
9
10
|
|
|
10
|
-
async def
|
|
11
|
+
async def default_get_refresh_token(*_, **__) -> str | None:
    """Default refresh-token resolver.

    Ignores its arguments and reads ``configurable._credentials.refresh_token``
    from the current runnable config, returning ``None`` when it is absent.
    """
    configurable = ensure_config().get("configurable", {})
    credentials = configurable.get("_credentials", {})
    return credentials.get("refresh_token")
|
|
12
13
|
|
|
13
14
|
class FederatedConnectionAuthorizer(FederatedConnectionAuthorizerBase, ABC):
    """LangChain adapter around the framework-agnostic Federated Connection authorizer.

    When no refresh-token resolver is configured, falls back to
    ``default_get_refresh_token`` (the runnable config's
    ``configurable._credentials.refresh_token``). Federated Connection interrupts
    are re-raised as LangGraph ``GraphInterrupt``s, and :meth:`authorizer`
    returns a decorator that routes a LangChain tool through ``self.protect``.
    """

    def __init__(
        self,
        params: FederatedConnectionAuthorizerParams,
        auth0: Auth0ClientParams | None = None,
    ):
        """Create the authorizer.

        Args:
            params: Authorizer parameters. Not mutated; a copy is made when the
                default refresh-token resolver has to be filled in.
            auth0: Optional Auth0 client parameters, forwarded to the base class.
        """
        if params.refresh_token.value is None:
            # ``copy.copy(params)`` alone is shallow: the copy would still share
            # the ``refresh_token`` holder with the caller's object, so assigning
            # ``.value`` would mutate the caller's params. Copy the holder too.
            params = copy.copy(params)
            params.refresh_token = copy.copy(params.refresh_token)
            params.refresh_token.value = default_get_refresh_token

        super().__init__(params, auth0)

    def _handle_authorization_interrupts(self, err: FederatedConnectionInterrupt) -> None:
        # Convert the Auth0 interrupt into the exception type LangGraph understands.
        raise to_graph_interrupt(err)

    def authorizer(self):
        """Return a decorator that wraps a LangChain tool with this authorizer."""

        def wrap_tool(tool: BaseTool) -> BaseTool:
            return tool_wrapper(tool, self.protect)

        return wrap_tool
|
|
@@ -1,13 +1,30 @@
|
|
|
1
|
+
from typing import List
|
|
1
2
|
from auth0_ai.interrupts.auth0_interrupt import Auth0Interrupt
|
|
2
3
|
from langgraph.errors import GraphInterrupt
|
|
4
|
+
from langgraph.types import Interrupt
|
|
5
|
+
from langgraph_sdk.schema import Thread
|
|
3
6
|
|
|
4
7
|
|
|
5
8
|
def to_graph_interrupt(interrupt: Auth0Interrupt) -> GraphInterrupt:
    """Wrap an Auth0 interrupt in a LangGraph ``GraphInterrupt``.

    The interrupt is serialized with ``to_json()`` and tagged with the
    namespace ``auth0AI:<name>:<code>``. The resulting ``Interrupt`` is marked
    resumable and raised "during" the node.
    """
    payload = interrupt.to_json()
    namespace = f"auth0AI:{interrupt.name}:{interrupt.code}"
    graph_interrupt = Interrupt(
        value=payload,
        when="during",
        resumable=True,
        ns=[namespace],
    )
    return GraphInterrupt([graph_interrupt])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_auth0_interrupts(thread: Thread) -> List[Interrupt]:
    """Return every Auth0 interrupt pending on a LangGraph thread.

    Args:
        thread: A LangGraph SDK thread (a dict) whose ``interrupts`` maps an
            interrupt key to a list of interrupts.

    Returns:
        The interrupts, across all groups and in iteration order, whose
        ``value`` is recognized by ``Auth0Interrupt.is_interrupt``. Empty when
        the thread has no interrupts.
    """
    # Treat a missing key and an explicit ``None`` value alike as "no interrupts";
    # previously ``None`` would crash on ``.values()``.
    interrupts_by_key = thread.get("interrupts") or {}

    return [
        interrupt
        for interrupt_list in interrupts_by_key.values()
        for interrupt in interrupt_list
        if Auth0Interrupt.is_interrupt(interrupt["value"])
    ]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
from typing_extensions import Annotated
|
|
3
|
+
from pydantic import create_model
|
|
4
|
+
from langchain_core.tools import BaseTool, tool as create_tool, InjectedToolCallId
|
|
5
|
+
from langchain_core.runnables import RunnableConfig
|
|
6
|
+
|
|
7
|
+
def tool_wrapper(tool: BaseTool, protect_fn: Callable) -> BaseTool:
    """Wrap ``tool`` so that every invocation goes through ``protect_fn``.

    The returned tool keeps the original name, description, arguments and
    ``return_direct`` flag. On each call it builds a ``protect_fn`` around the
    original tool's ``ainvoke`` and passes it a context with the current
    ``thread_id`` (from the runnable config), the ``tool_call_id`` and the
    tool name.

    Args:
        tool: The LangChain tool to protect. Its ``args_schema`` is assumed to
            be a pydantic model class (``__name__`` and ``__base__`` use).
        protect_fn: An authorizer's ``protect(get_context, execute_fn)``. It must
            return an async callable that takes the tool's keyword arguments.

    Returns:
        A new LangChain tool wrapping ``tool``.
    """

    # Workaround: extend existing args_schema to be able to get the tool_call_id value
    args_schema = create_model(
        tool.args_schema.__name__ + "Extended",
        __base__=tool.args_schema,
        **{"tool_call_id": (Annotated[str, InjectedToolCallId])}
    )

    @create_tool(
        tool.name,
        description=tool.description,
        args_schema=args_schema,
        # Without this the wrapper would silently reset return_direct to the
        # decorator's default, unlike the wrapped tool (and the 0.1.x wrapper).
        return_direct=tool.return_direct,
    )
    async def wrapped_tool(config: RunnableConfig, tool_call_id: Annotated[str, InjectedToolCallId], **tool_input):
        async def execute_fn(*_, **__):
            # Delegate to the original tool with the user-supplied arguments only.
            return await tool.ainvoke(tool_input, config)

        return await protect_fn(
            lambda *_, **__: {
                "thread_id": config.get("configurable", {}).get("thread_id"),
                "tool_call_id": tool_call_id,
                "tool_name": tool.name,
            },
            execute_fn,
        )(**tool_input)

    return wrapped_tool
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: auth0-ai-langchain
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 1.0.0b1
|
|
4
4
|
Summary: This package is an SDK for building secure AI-powered applications using Auth0, Okta FGA and LangChain.
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Auth0
|
|
@@ -11,12 +11,12 @@ Classifier: Programming Language :: Python :: 3
|
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.12
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.13
|
|
14
|
-
Requires-Dist: auth0-ai (>=
|
|
15
|
-
Requires-Dist: langchain (>=0.3.
|
|
16
|
-
Requires-Dist: langchain-core (>=0.3.
|
|
17
|
-
Requires-Dist: langgraph (>=0.3
|
|
18
|
-
Requires-Dist: langgraph-sdk (>=0.1.
|
|
19
|
-
Requires-Dist: openfga-sdk (>=0.9.
|
|
14
|
+
Requires-Dist: auth0-ai (>=1.0.0b1,<2.0.0)
|
|
15
|
+
Requires-Dist: langchain (>=0.3.25,<0.4.0)
|
|
16
|
+
Requires-Dist: langchain-core (>=0.3.59,<0.4.0)
|
|
17
|
+
Requires-Dist: langgraph (>=0.4.3,<0.5.0)
|
|
18
|
+
Requires-Dist: langgraph-sdk (>=0.1.66,<0.2.0)
|
|
19
|
+
Requires-Dist: openfga-sdk (>=0.9.4,<0.10.0)
|
|
20
20
|
Project-URL: Homepage, https://auth0.com
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
|
|
@@ -28,13 +28,58 @@ Description-Content-Type: text/markdown
|
|
|
28
28
|
|
|
29
29
|
## Installation
|
|
30
30
|
|
|
31
|
-
>
|
|
32
|
-
> `auth0-ai-langchain` is currently under development and it is not intended to be used in production, and therefore has no official support.
|
|
31
|
+
> ⚠️ **WARNING**: `auth0-ai-langchain` is currently under development and it is not intended to be used in production, and therefore has no official support.
|
|
33
32
|
|
|
34
33
|
```bash
|
|
35
34
|
pip install auth0-ai-langchain
|
|
36
35
|
```
|
|
37
36
|
|
|
37
|
+
## Async User Confirmation
|
|
38
|
+
|
|
39
|
+
`Auth0AI` uses CIBA (Client-Initiated Backchannel Authentication) to handle user confirmation asynchronously. This is useful when you need to confirm a user action before proceeding with a tool execution.
|
|
40
|
+
|
|
41
|
+
Full Example of [Async User Confirmation](https://github.com/auth0-lab/auth0-ai-python/tree/main/examples/async-user-confirmation/langchain-examples).
|
|
42
|
+
|
|
43
|
+
1. Define a tool with the proper authorizer specifying a function to resolve the user id:
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
import os
from auth0_ai_langchain.auth0_ai import Auth0AI
|
|
47
|
+
from auth0_ai_langchain.ciba import get_ciba_credentials
|
|
48
|
+
from langchain_core.runnables import ensure_config
|
|
49
|
+
from langchain_core.tools import StructuredTool
|
|
50
|
+
|
|
51
|
+
# If not provided, Auth0 settings will be read from env variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`
|
|
52
|
+
auth0_ai = Auth0AI()
|
|
53
|
+
|
|
54
|
+
with_async_user_confirmation = auth0_ai.with_async_user_confirmation(
|
|
55
|
+
scope="stock:trade",
|
|
56
|
+
audience=os.getenv("AUDIENCE"),
|
|
57
|
+
binding_message=lambda ticker, qty: f"Authorize the purchase of {qty} {ticker}",
|
|
58
|
+
user_id=lambda *_, **__: ensure_config().get("configurable", {}).get("user_id"),
|
|
59
|
+
# Optional:
|
|
60
|
+
# store=InMemoryStore()
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
def tool_function(ticker: str, qty: int) -> str:
|
|
64
|
+
credentials = get_ciba_credentials()
|
|
65
|
+
headers = {
|
|
66
|
+
"Authorization": f"{credentials['token_type']} {credentials['access_token']}",
|
|
67
|
+
# ...
|
|
68
|
+
}
|
|
69
|
+
# Call API
|
|
70
|
+
|
|
71
|
+
trade_tool = with_async_user_confirmation(
|
|
72
|
+
StructuredTool(
|
|
73
|
+
name="trade_tool",
|
|
74
|
+
description="Use this function to trade a stock",
|
|
75
|
+
func=tool_function,
|
|
76
|
+
# ...
|
|
77
|
+
)
|
|
78
|
+
)
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
2. Handle interrupts properly. For example, if the user is not enrolled in MFA, the tool raises an interrupt. See the [Handling Interrupts](#handling-interrupts) section.
|
|
82
|
+
|
|
38
83
|
## Authorization for Tools
|
|
39
84
|
|
|
40
85
|
The `FGAAuthorizer` can leverage Okta FGA to authorize tools executions. The `FGAAuthorizer.create` function can be used to create an authorizer that checks permissions before executing the tool.
|
|
@@ -44,19 +89,12 @@ Full example of [Authorization for Tools](https://github.com/auth0-lab/auth0-ai-
|
|
|
44
89
|
1. Create an instance of FGA Authorizer:
|
|
45
90
|
|
|
46
91
|
```python
|
|
47
|
-
from auth0_ai_langchain.fga
|
|
92
|
+
from auth0_ai_langchain.fga import FGAAuthorizer
|
|
48
93
|
|
|
94
|
+
# If not provided, FGA settings will be read from env variables: `FGA_STORE_ID`, `FGA_CLIENT_ID`, `FGA_CLIENT_SECRET`, etc.
|
|
49
95
|
fga = FGAAuthorizer.create()
|
|
50
96
|
```
|
|
51
97
|
|
|
52
|
-
**Note**: Here, you can configure and specify your FGA credentials. By `default`, they are read from environment variables:
|
|
53
|
-
|
|
54
|
-
```sh
|
|
55
|
-
FGA_STORE_ID="<fga-store-id>"
|
|
56
|
-
FGA_CLIENT_ID="<fga-client-id>"
|
|
57
|
-
FGA_CLIENT_SECRET="<fga-client-secret>"
|
|
58
|
-
```
|
|
59
|
-
|
|
60
98
|
2. Define the FGA query (`build_query`) and, optionally, the `on_unauthorized` handler:
|
|
61
99
|
|
|
62
100
|
```python
|
|
@@ -74,10 +112,10 @@ async def build_fga_query(tool_input):
|
|
|
74
112
|
def on_unauthorized(tool_input):
|
|
75
113
|
return f"The user is not allowed to buy {tool_input['qty']} shares of {tool_input['ticker']}."
|
|
76
114
|
|
|
77
|
-
use_fga = fga(
|
|
115
|
+
use_fga = fga(
|
|
78
116
|
build_query=build_fga_query,
|
|
79
117
|
on_unauthorized=on_unauthorized,
|
|
80
|
-
)
|
|
118
|
+
)
|
|
81
119
|
```
|
|
82
120
|
|
|
83
121
|
**Note**: The parameters given to the `build_query` and `on_unauthorized` functions are the same as those provided to the tool function.
|
|
@@ -103,7 +141,7 @@ buy_tool = StructuredTool(
|
|
|
103
141
|
|
|
104
142
|
## Calling APIs On User's Behalf
|
|
105
143
|
|
|
106
|
-
The `Auth0AI.with_federated_connection` function exchanges user's refresh token taken from the runnable configuration (`config.configurable._credentials.refresh_token`) for a Federated Connection API token.
|
|
144
|
+
The `Auth0AI.with_federated_connection` function exchanges the user's refresh token (taken, by default, from the runnable configuration: `config.configurable._credentials.refresh_token`) for a Federated Connection API token.
|
|
107
145
|
|
|
108
146
|
Full Example of [Calling APIs On User's Behalf](https://github.com/auth0-lab/auth0-ai-python/tree/main/examples/calling-apis/langchain-examples).
|
|
109
147
|
|
|
@@ -111,19 +149,23 @@ Full Example of [Calling APIs On User's Behalf](https://github.com/auth0-lab/aut
|
|
|
111
149
|
|
|
112
150
|
```python
|
|
113
151
|
from auth0_ai_langchain.auth0_ai import Auth0AI
|
|
114
|
-
from auth0_ai_langchain.federated_connections import
|
|
152
|
+
from auth0_ai_langchain.federated_connections import get_credentials_for_connection
|
|
115
153
|
from datetime import datetime
from langchain_core.tools import StructuredTool
|
|
116
154
|
|
|
155
|
+
# If not provided, Auth0 settings will be read from env variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`
|
|
117
156
|
auth0_ai = Auth0AI()
|
|
118
157
|
|
|
119
158
|
with_google_calendar_access = auth0_ai.with_federated_connection(
|
|
120
159
|
connection="google-oauth2",
|
|
121
|
-
scopes=["https://www.googleapis.com/auth/calendar.freebusy"]
|
|
160
|
+
scopes=["https://www.googleapis.com/auth/calendar.freebusy"],
|
|
161
|
+
# Optional:
|
|
162
|
+
# refresh_token=lambda *_, **__: ensure_config().get("configurable", {}).get("_credentials", {}).get("refresh_token"),
|
|
163
|
+
# store=InMemoryStore(),
|
|
122
164
|
)
|
|
123
165
|
|
|
124
166
|
def tool_function(date: datetime):
|
|
125
|
-
|
|
126
|
-
# Call Google API
|
|
167
|
+
credentials = get_credentials_for_connection()
|
|
168
|
+
# Call Google API using credentials["access_token"]
|
|
127
169
|
|
|
128
170
|
check_calendar_tool = with_google_calendar_access(
|
|
129
171
|
StructuredTool(
|
|
@@ -155,7 +197,7 @@ workflow = (
|
|
|
155
197
|
)
|
|
156
198
|
```
|
|
157
199
|
|
|
158
|
-
3. Handle interruptions properly.
|
|
200
|
+
3. Handle interrupts properly. For example, if the tool does not have access to the user's Google Calendar, it raises an interrupt. See the [Handling Interrupts](#handling-interrupts) section.
|
|
159
201
|
|
|
160
202
|
## RAG with FGA
|
|
161
203
|
|
|
@@ -185,7 +227,8 @@ vector_store = VectorStoreIndex.from_documents(documents)
|
|
|
185
227
|
# Create a retriever:
|
|
186
228
|
base_retriever = vector_store.as_retriever()
|
|
187
229
|
|
|
188
|
-
# Create the FGA retriever wrapper
|
|
230
|
+
# Create the FGA retriever wrapper.
|
|
231
|
+
# If not provided, FGA settings will be read from env variables: `FGA_STORE_ID`, `FGA_CLIENT_ID`, `FGA_CLIENT_SECRET`, etc.
|
|
189
232
|
retriever = FGARetriever(
|
|
190
233
|
base_retriever,
|
|
191
234
|
build_query=lambda node: ClientCheckRequest(
|
|
@@ -207,6 +250,63 @@ response = query_engine.query("What is the forecast for ZEKO?")
|
|
|
207
250
|
print(response)
|
|
208
251
|
```
|
|
209
252
|
|
|
253
|
+
## Handling Interrupts
|
|
254
|
+
|
|
255
|
+
`Auth0AI` uses interrupts extensively and will never block a graph. Whenever an authorizer requires user interaction, the graph throws a `GraphInterrupt` exception with data that allows the client to resume the flow.
|
|
256
|
+
|
|
257
|
+
It is important to disable error handling in your tools node as follows:
|
|
258
|
+
|
|
259
|
+
```python
|
|
260
|
+
.add_node(
|
|
261
|
+
"tools",
|
|
262
|
+
ToolNode(
|
|
263
|
+
[
|
|
264
|
+
# your authorizer-wrapped tools
|
|
265
|
+
],
|
|
266
|
+
# Error handler should be disabled in order to trigger interruptions from within tools.
|
|
267
|
+
handle_tool_errors=False
|
|
268
|
+
)
|
|
269
|
+
)
|
|
270
|
+
```
|
|
271
|
+
|
|
272
|
+
On the client side, retrieve the Auth0 interrupts from the graph thread:
|
|
273
|
+
|
|
274
|
+
```python
|
|
275
|
+
from auth0_ai_langchain.utils.interrupt import get_auth0_interrupts
|
|
276
|
+
|
|
277
|
+
# Get the langgraph thread:
|
|
278
|
+
thread = await client.threads.get(thread_id)
|
|
279
|
+
|
|
280
|
+
# Filter the auth0 interrupts:
|
|
281
|
+
auth0_interrupts = get_auth0_interrupts(thread)
|
|
282
|
+
```
|
|
283
|
+
|
|
284
|
+
Then you can resume the thread by doing this:
|
|
285
|
+
|
|
286
|
+
```python
|
|
287
|
+
await client.runs.wait(thread_id, assistant_id)
|
|
288
|
+
```
|
|
289
|
+
|
|
290
|
+
For the specific case of **CIBA (Client-Initiated Backchannel Authorization)** you might attach a `GraphResumer` instance that watches for interrupted threads in the `"Authorization Pending"` state and attempts to resume them automatically, respecting Auth0's polling interval.
|
|
291
|
+
|
|
292
|
+
```python
|
|
293
|
+
import os
|
|
294
|
+
from auth0_ai_langchain.ciba import GraphResumer
|
|
295
|
+
from langgraph_sdk import get_client
|
|
296
|
+
|
|
297
|
+
resumer = GraphResumer(
|
|
298
|
+
lang_graph=get_client(url=os.getenv("LANGGRAPH_API_URL")),
|
|
299
|
+
# optionally, you can filter by a specific graph:
|
|
300
|
+
filters={"graph_id": "conditional-trade"},
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
resumer \
|
|
304
|
+
.on_resume(lambda thread: print(f"Attempting to resume thread {thread['thread_id']} from interruption {thread['interruption_id']}")) \
|
|
305
|
+
.on_error(lambda err: print(f"Error in GraphResumer: {str(err)}"))
|
|
306
|
+
|
|
307
|
+
resumer.start()
|
|
308
|
+
```
|
|
309
|
+
|
|
210
310
|
---
|
|
211
311
|
|
|
212
312
|
<p align="center">
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
auth0_ai_langchain/FGARetriever.py,sha256=SQwxo2aDtQhwQtYmszoKw3BH-U5QVnvPAgVw9EDzKVM,6002
|
|
2
|
+
auth0_ai_langchain/__init__.py,sha256=I331Kz-q97ZU7TfXaOR5UBbJamGEJ15twbf2HP1iCHs,67
|
|
3
|
+
auth0_ai_langchain/auth0_ai.py,sha256=J3fxYNZf0KMK2w085dCdGfCRyafQGWPAI19edcYpQi8,4732
|
|
4
|
+
auth0_ai_langchain/ciba/__init__.py,sha256=X62HZB20XdhsgcaKld6rLm2BOSuiO5uU5v7ePQz27Mk,268
|
|
5
|
+
auth0_ai_langchain/ciba/ciba_authorizer.py,sha256=GRAB3NBnmoxAECrRjPNdA9N9uQ4pCEzP6dF8RUwlysM,766
|
|
6
|
+
auth0_ai_langchain/ciba/graph_resumer.py,sha256=EpdzzB_NccdggKA3x__Q3Yziejo7AJjR4aJ57TZmYPA,5474
|
|
7
|
+
auth0_ai_langchain/federated_connections/__init__.py,sha256=kGpPN9ntsyvE-2m_lcdCVvPevadyILCk3NiAm4TN0QA,429
|
|
8
|
+
auth0_ai_langchain/federated_connections/federated_connection_authorizer.py,sha256=o25oRGiTo9y5mpjDNEWWaFVAIFbhwaxC0pcRId-4oYE,1405
|
|
9
|
+
auth0_ai_langchain/fga/__init__.py,sha256=rgqTD4Gvz28jNdqhxTG5udbgyeUMsyvRj83fHBJdt4s,137
|
|
10
|
+
auth0_ai_langchain/utils/interrupt.py,sha256=DZ1b9OAkg3SQru9mSaQGBC6UY0ODz7QSskS9RlVyEGw,860
|
|
11
|
+
auth0_ai_langchain/utils/tool_wrapper.py,sha256=dHjcqykT2aohdFOm0mLZ9U6bXB6NHjfABb3aXef5174,1210
|
|
12
|
+
auth0_ai_langchain-1.0.0b1.dist-info/LICENSE,sha256=Lu_2YH0oK8b_VVisAhNQ2WIdtwY8pSU2PLbll-y6Cj8,9792
|
|
13
|
+
auth0_ai_langchain-1.0.0b1.dist-info/METADATA,sha256=lyc_GI9ymhgIrQh2wM2fv-lYF_uWT9VXFanst8HowEs,11790
|
|
14
|
+
auth0_ai_langchain-1.0.0b1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
|
15
|
+
auth0_ai_langchain-1.0.0b1.dist-info/RECORD,,
|
|
@@ -1,109 +0,0 @@
|
|
|
1
|
-
from typing import Awaitable, Hashable, List, Optional, Callable, Any, Union
|
|
2
|
-
from langchain_core.tools import StructuredTool
|
|
3
|
-
from langchain_core.tools.base import BaseTool
|
|
4
|
-
from langgraph.graph import StateGraph, END, START
|
|
5
|
-
from langchain_core.runnables import Runnable
|
|
6
|
-
from auth0_ai.authorizers.types import AuthorizerParams
|
|
7
|
-
from ..types import Auth0Nodes
|
|
8
|
-
from .initialize_ciba import initialize_ciba
|
|
9
|
-
from .initialize_hitl import initialize_hitl
|
|
10
|
-
from .types import CIBAGraphOptions, CIBAOptions, ProtectedTool, BaseState
|
|
11
|
-
|
|
12
|
-
class CIBAGraph():
|
|
13
|
-
def __init__(
|
|
14
|
-
self,
|
|
15
|
-
options: Optional[CIBAGraphOptions] = None,
|
|
16
|
-
authorizer_params: Optional[AuthorizerParams] = None,
|
|
17
|
-
):
|
|
18
|
-
self.options = options
|
|
19
|
-
self.authorizer_params = authorizer_params
|
|
20
|
-
self.tools: List[ProtectedTool] = []
|
|
21
|
-
self.graph: Optional[StateGraph] = None
|
|
22
|
-
|
|
23
|
-
def get_tools(self) -> List[ProtectedTool]:
|
|
24
|
-
return self.tools
|
|
25
|
-
|
|
26
|
-
def get_graph(self) -> Optional[StateGraph]:
|
|
27
|
-
return self.graph
|
|
28
|
-
|
|
29
|
-
def get_options(self) -> Optional[CIBAGraphOptions]:
|
|
30
|
-
return self.options
|
|
31
|
-
|
|
32
|
-
def get_authorizer_params(self) -> Optional[AuthorizerParams]:
|
|
33
|
-
return self.authorizer_params
|
|
34
|
-
|
|
35
|
-
def register_nodes(
|
|
36
|
-
self,
|
|
37
|
-
graph: StateGraph,
|
|
38
|
-
) -> StateGraph:
|
|
39
|
-
self.graph = graph
|
|
40
|
-
|
|
41
|
-
# Add CIBA HITL and CIBA nodes
|
|
42
|
-
self.graph.add_node(Auth0Nodes.AUTH0_CIBA_HITL.value, initialize_hitl(self))
|
|
43
|
-
self.graph.add_node(Auth0Nodes.AUTH0_CIBA.value, initialize_ciba(self))
|
|
44
|
-
self.graph.add_conditional_edges(
|
|
45
|
-
Auth0Nodes.AUTH0_CIBA.value,
|
|
46
|
-
lambda state: END if getattr(state, "auth0", {}).get("error") else Auth0Nodes.AUTH0_CIBA_HITL.value,
|
|
47
|
-
)
|
|
48
|
-
|
|
49
|
-
return graph
|
|
50
|
-
|
|
51
|
-
def protect_tool(
|
|
52
|
-
self,
|
|
53
|
-
tool: Union[BaseTool, Callable],
|
|
54
|
-
options: CIBAOptions,
|
|
55
|
-
) -> StructuredTool:
|
|
56
|
-
"""
|
|
57
|
-
Authorize Options to start CIBA flow.
|
|
58
|
-
|
|
59
|
-
Attributes:
|
|
60
|
-
tool (Union[BaseTool, Callable]): The tool to be protected.
|
|
61
|
-
options (CIBAOptions): The CIBA options.
|
|
62
|
-
"""
|
|
63
|
-
|
|
64
|
-
# Merge default options with tool-specific options
|
|
65
|
-
merged_options = {**self.options, **options.__dict__} if isinstance(self.options, dict) else {**vars(self.options), **vars(options)}
|
|
66
|
-
|
|
67
|
-
if merged_options["on_approve_go_to"] is None:
|
|
68
|
-
raise ValueError(f"[{tool.name}] on_approve_go_to is required")
|
|
69
|
-
|
|
70
|
-
if merged_options["on_reject_go_to"] is None:
|
|
71
|
-
raise ValueError(f"[{tool.name}] on_reject_go_to is required")
|
|
72
|
-
|
|
73
|
-
self.tools.append(ProtectedTool(tool_name=tool.name, options=merged_options))
|
|
74
|
-
|
|
75
|
-
return tool
|
|
76
|
-
|
|
77
|
-
def with_auth(self, path: Union[
|
|
78
|
-
Callable[..., Union[Hashable, list[Hashable]]],
|
|
79
|
-
Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
|
|
80
|
-
Runnable[Any, Union[Hashable, list[Hashable]]],
|
|
81
|
-
]):
|
|
82
|
-
"""
|
|
83
|
-
A wrapper for the callable that determines the next node or nodes using a protected tool.
|
|
84
|
-
|
|
85
|
-
Attributes:
|
|
86
|
-
path (Union[Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]]])): The callable that determines the next node or nodes using a protected tool.
|
|
87
|
-
"""
|
|
88
|
-
def wrapper(*args):
|
|
89
|
-
if not callable(path):
|
|
90
|
-
return START
|
|
91
|
-
|
|
92
|
-
state: BaseState = args[0]
|
|
93
|
-
messages = state.get("messages")
|
|
94
|
-
last_message = messages[-1] if messages else None
|
|
95
|
-
|
|
96
|
-
# Call default path if there are no tool calls
|
|
97
|
-
if not last_message or not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
|
|
98
|
-
return path(*args)
|
|
99
|
-
|
|
100
|
-
tool_name = last_message.tool_calls[0]["name"]
|
|
101
|
-
tool = next((t for t in self.tools if t.tool_name == tool_name), None)
|
|
102
|
-
|
|
103
|
-
if tool:
|
|
104
|
-
return Auth0Nodes.AUTH0_CIBA.value
|
|
105
|
-
|
|
106
|
-
# Call default path if tool is not protected
|
|
107
|
-
return path(*args)
|
|
108
|
-
|
|
109
|
-
return wrapper
|
|
@@ -1,91 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from langgraph.types import Command
|
|
3
|
-
from langgraph_sdk import get_client
|
|
4
|
-
from langchain_core.runnables.config import RunnableConfig
|
|
5
|
-
from auth0_ai.authorizers.ciba_authorizer import CIBAAuthorizer
|
|
6
|
-
from ..types import Auth0Graphs, Auth0Nodes
|
|
7
|
-
from .types import ICIBAGraph, BaseState
|
|
8
|
-
from .utils import get_tool_definition
|
|
9
|
-
|
|
10
|
-
def initialize_ciba(ciba_graph: ICIBAGraph):
|
|
11
|
-
async def handler(state: BaseState, config: RunnableConfig):
|
|
12
|
-
try:
|
|
13
|
-
ciba_params = ciba_graph.get_options()
|
|
14
|
-
tools = ciba_graph.get_tools()
|
|
15
|
-
tool_definition = get_tool_definition(state, tools)
|
|
16
|
-
|
|
17
|
-
if not tool_definition:
|
|
18
|
-
return Command(resume=True)
|
|
19
|
-
|
|
20
|
-
graph = ciba_graph.get_graph()
|
|
21
|
-
metadata, tool = tool_definition["metadata"], tool_definition["tool"]
|
|
22
|
-
ciba_options = metadata.options
|
|
23
|
-
|
|
24
|
-
langgraph = get_client(url=os.getenv("LANGGRAPH_API_URL", "http://localhost:54367"))
|
|
25
|
-
|
|
26
|
-
# Check if CIBA Poller Graph exists
|
|
27
|
-
search_result = await langgraph.assistants.search(graph_id=Auth0Graphs.CIBA_POLLER.value)
|
|
28
|
-
if not search_result:
|
|
29
|
-
raise ValueError(
|
|
30
|
-
f"[{Auth0Nodes.AUTH0_CIBA}] \"{Auth0Graphs.CIBA_POLLER}\" does not exist. Make sure to register the graph in your \"langgraph.json\"."
|
|
31
|
-
)
|
|
32
|
-
|
|
33
|
-
if ciba_options["on_approve_go_to"] not in graph.nodes:
|
|
34
|
-
raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"{ciba_options["on_approve_go_to"]}\" is not a valid node.")
|
|
35
|
-
|
|
36
|
-
if ciba_options["on_reject_go_to"] not in graph.nodes:
|
|
37
|
-
raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"{ciba_options["on_reject_go_to"]}\" is not a valid node.")
|
|
38
|
-
|
|
39
|
-
scheduler = ciba_params.config["scheduler"]
|
|
40
|
-
on_resume_invoke = ciba_params.config["on_resume_invoke"]
|
|
41
|
-
audience = ciba_params.audience
|
|
42
|
-
|
|
43
|
-
if not scheduler:
|
|
44
|
-
raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"scheduler\" must be a \"function\" or a \"string\".")
|
|
45
|
-
|
|
46
|
-
if not on_resume_invoke:
|
|
47
|
-
raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"on_resume_invoke\" must be defined.")
|
|
48
|
-
|
|
49
|
-
user_id = config.get("configurable", {}).get("user_id")
|
|
50
|
-
thread_id = config.get("metadata", {}).get("thread_id")
|
|
51
|
-
|
|
52
|
-
ciba_response = await CIBAAuthorizer.start(
|
|
53
|
-
{
|
|
54
|
-
"user_id": user_id,
|
|
55
|
-
"scope": ciba_options["scope"] or "openid",
|
|
56
|
-
"audience": audience,
|
|
57
|
-
"binding_message": ciba_options["binding_message"],
|
|
58
|
-
},
|
|
59
|
-
ciba_graph.get_authorizer_params(),
|
|
60
|
-
tool["args"],
|
|
61
|
-
)
|
|
62
|
-
|
|
63
|
-
scheduler_params = {
|
|
64
|
-
"tool_id": tool["id"],
|
|
65
|
-
"user_id": user_id,
|
|
66
|
-
"ciba_graph_id": Auth0Graphs.CIBA_POLLER.value,
|
|
67
|
-
"thread_id": thread_id,
|
|
68
|
-
"ciba_response": ciba_response,
|
|
69
|
-
"on_resume_invoke": on_resume_invoke,
|
|
70
|
-
}
|
|
71
|
-
|
|
72
|
-
if callable(scheduler):
|
|
73
|
-
# Use Custom Scheduler
|
|
74
|
-
await scheduler(scheduler_params)
|
|
75
|
-
elif isinstance(scheduler, str):
|
|
76
|
-
# Use Langgraph SDK to schedule the task
|
|
77
|
-
await langgraph.crons.create_for_thread(
|
|
78
|
-
thread_id,
|
|
79
|
-
scheduler_params["ciba_graph_id"],
|
|
80
|
-
schedule="*/1 * * * *", # Default to every minute
|
|
81
|
-
input=scheduler_params,
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
print("CIBA Task Scheduled")
|
|
85
|
-
except Exception as e:
|
|
86
|
-
print(e)
|
|
87
|
-
state["auth0"] = {"error": str(e)}
|
|
88
|
-
|
|
89
|
-
return state
|
|
90
|
-
|
|
91
|
-
return handler
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
from typing import Awaitable, Callable
|
|
2
|
-
from langchain_core.messages import ToolMessage, AIMessage, ToolCall
|
|
3
|
-
from langgraph.types import interrupt, Command
|
|
4
|
-
from .types import ICIBAGraph, BaseState
|
|
5
|
-
from .utils import get_tool_definition
|
|
6
|
-
from auth0_ai.authorizers.ciba_authorizer import CibaAuthorizerCheckResponse
|
|
7
|
-
|
|
8
|
-
def initialize_hitl(ciba_graph: ICIBAGraph) -> Callable[[BaseState], Awaitable[Command]]:
|
|
9
|
-
async def handler(state: BaseState) -> Command:
|
|
10
|
-
tools = ciba_graph.get_tools()
|
|
11
|
-
tool_definition = get_tool_definition(state, tools)
|
|
12
|
-
|
|
13
|
-
# if no tool calls, resume
|
|
14
|
-
if not tool_definition:
|
|
15
|
-
return Command(resume=True)
|
|
16
|
-
|
|
17
|
-
# wait for user approval
|
|
18
|
-
human_review = interrupt("A push notification has been sent to your device.")
|
|
19
|
-
|
|
20
|
-
metadata, tool, message = tool_definition["metadata"], tool_definition["tool"], tool_definition["message"]
|
|
21
|
-
|
|
22
|
-
if human_review["status"] == CibaAuthorizerCheckResponse.APPROVED.value:
|
|
23
|
-
updated_message = AIMessage(
|
|
24
|
-
id=message.id,
|
|
25
|
-
content="The user has approved the transaction",
|
|
26
|
-
tool_calls=[
|
|
27
|
-
ToolCall(
|
|
28
|
-
name=tool["name"],
|
|
29
|
-
args=tool["args"],
|
|
30
|
-
id=tool["id"],
|
|
31
|
-
)
|
|
32
|
-
],
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
return Command(
|
|
36
|
-
goto=metadata.options["on_approve_go_to"],
|
|
37
|
-
update={"messages": [updated_message]},
|
|
38
|
-
)
|
|
39
|
-
else:
|
|
40
|
-
tool_message = ToolMessage(
|
|
41
|
-
name=tool["name"],
|
|
42
|
-
content="The user has rejected the transaction.",
|
|
43
|
-
tool_call_id=tool["id"],
|
|
44
|
-
)
|
|
45
|
-
return Command(
|
|
46
|
-
goto=metadata.options["on_reject_go_to"],
|
|
47
|
-
update={"messages": [tool_message]},
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
return handler
|
|
@@ -1,115 +0,0 @@
|
|
|
1
|
-
from typing import Optional, List, Callable, Union, Awaitable, TypedDict
|
|
2
|
-
from abc import ABC, abstractmethod
|
|
3
|
-
from langgraph.graph import StateGraph
|
|
4
|
-
from langchain_core.messages import AIMessage, ToolMessage
|
|
5
|
-
from auth0_ai.authorizers.types import AuthorizerParams
|
|
6
|
-
from auth0_ai.authorizers.ciba_authorizer import AuthorizeResponse
|
|
7
|
-
|
|
8
|
-
class Auth0State(TypedDict):
|
|
9
|
-
error: str
|
|
10
|
-
|
|
11
|
-
class BaseState(TypedDict):
|
|
12
|
-
task_id: str
|
|
13
|
-
messages: List[Union[AIMessage, ToolMessage]]
|
|
14
|
-
auth0: Optional[Auth0State] = None
|
|
15
|
-
|
|
16
|
-
class SchedulerParams:
|
|
17
|
-
def __init__(
|
|
18
|
-
self,
|
|
19
|
-
user_id: str,
|
|
20
|
-
thread_id: str,
|
|
21
|
-
ciba_graph_id: str,
|
|
22
|
-
ciba_response: AuthorizeResponse,
|
|
23
|
-
tool_id: Optional[str] = None,
|
|
24
|
-
on_resume_invoke: str = "",
|
|
25
|
-
):
|
|
26
|
-
self.user_id = user_id
|
|
27
|
-
self.thread_id = thread_id
|
|
28
|
-
self.tool_id = tool_id
|
|
29
|
-
self.on_resume_invoke = on_resume_invoke
|
|
30
|
-
self.ciba_graph_id = ciba_graph_id
|
|
31
|
-
self.ciba_response = ciba_response
|
|
32
|
-
|
|
33
|
-
class CIBAOptions():
|
|
34
|
-
"""
|
|
35
|
-
The CIBA options.
|
|
36
|
-
|
|
37
|
-
Attributes:
|
|
38
|
-
binding_message (Union[str, Callable[..., Awaitable[str]]]): A human-readable string to display to the user, or a function that resolves it.
|
|
39
|
-
scope (Optional[str]): Space-separated list of OIDC and custom API scopes.
|
|
40
|
-
on_approve_go_to (Optional[str]): A node name to redirect the flow after user approval.
|
|
41
|
-
on_reject_go_to (Optional[str]): A node name to redirect the flow after user rejection.
|
|
42
|
-
audience (Optional[str]): Unique identifier of the audience for an issued token.
|
|
43
|
-
request_expiry (Optional[int]): To configure a custom expiry time in seconds for CIBA request, pass a number between 1 and 300.
|
|
44
|
-
"""
|
|
45
|
-
def __init__(
|
|
46
|
-
self,
|
|
47
|
-
binding_message: Union[str, Callable[..., Awaitable[str]]],
|
|
48
|
-
scope: Optional[str] = None,
|
|
49
|
-
on_approve_go_to: Optional[str] = None,
|
|
50
|
-
on_reject_go_to: Optional[str] = None,
|
|
51
|
-
audience: Optional[str] = None,
|
|
52
|
-
request_expiry: Optional[int] = None,
|
|
53
|
-
):
|
|
54
|
-
self.binding_message = binding_message
|
|
55
|
-
self.scope = scope
|
|
56
|
-
self.on_approve_go_to = on_approve_go_to
|
|
57
|
-
self.on_reject_go_to = on_reject_go_to
|
|
58
|
-
self.audience = audience
|
|
59
|
-
self.request_expiry = request_expiry
|
|
60
|
-
|
|
61
|
-
class ProtectedTool():
|
|
62
|
-
def __init__(self, tool_name: str, options: CIBAOptions):
|
|
63
|
-
self.tool_name = tool_name
|
|
64
|
-
self.options = options
|
|
65
|
-
|
|
66
|
-
class CIBAGraphOptionsConfig:
|
|
67
|
-
def __init__(self, on_resume_invoke: str, scheduler: Union[str, Callable[[SchedulerParams], Awaitable[None]]]):
|
|
68
|
-
self.on_resume_invoke = on_resume_invoke
|
|
69
|
-
self.scheduler = scheduler
|
|
70
|
-
|
|
71
|
-
class CIBAGraphOptions():
|
|
72
|
-
"""
|
|
73
|
-
The base CIBA options.
|
|
74
|
-
|
|
75
|
-
Attributes:
|
|
76
|
-
config (CIBAGraphOptionsConfig): Configuration options.
|
|
77
|
-
scope (Optional[str]): Space-separated list of OIDC and custom API scopes.
|
|
78
|
-
on_approve_go_to (Optional[str]): A node name to redirect the flow after user approval.
|
|
79
|
-
on_reject_go_to (Optional[str]): A node name to redirect the flow after user rejection.
|
|
80
|
-
audience (Optional[str]): Unique identifier of the audience for an issued token.
|
|
81
|
-
request_expiry (Optional[int]): To configure a custom expiry time in seconds for CIBA request, pass a number between 1 and 300.
|
|
82
|
-
"""
|
|
83
|
-
def __init__(
|
|
84
|
-
self,
|
|
85
|
-
config: CIBAGraphOptionsConfig,
|
|
86
|
-
scope: Optional[str] = None,
|
|
87
|
-
on_approve_go_to: Optional[str] = None,
|
|
88
|
-
on_reject_go_to: Optional[str] = None,
|
|
89
|
-
audience: Optional[str] = None,
|
|
90
|
-
request_expiry: Optional[int] = None,
|
|
91
|
-
|
|
92
|
-
):
|
|
93
|
-
self.config = config
|
|
94
|
-
self.scope = scope
|
|
95
|
-
self.on_approve_go_to = on_approve_go_to
|
|
96
|
-
self.on_reject_go_to = on_reject_go_to
|
|
97
|
-
self.audience = audience
|
|
98
|
-
self.request_expiry = request_expiry
|
|
99
|
-
|
|
100
|
-
class ICIBAGraph(ABC):
|
|
101
|
-
@abstractmethod
|
|
102
|
-
def get_tools(self) -> List[ProtectedTool]:
|
|
103
|
-
pass
|
|
104
|
-
|
|
105
|
-
@abstractmethod
|
|
106
|
-
def get_graph(self) -> StateGraph:
|
|
107
|
-
pass
|
|
108
|
-
|
|
109
|
-
@abstractmethod
|
|
110
|
-
def get_authorizer_params(self) -> Optional[AuthorizerParams]:
|
|
111
|
-
pass
|
|
112
|
-
|
|
113
|
-
@abstractmethod
|
|
114
|
-
def get_options(self) -> Optional[CIBAGraphOptions]:
|
|
115
|
-
pass
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
from typing import Optional, List
|
|
2
|
-
from .types import ProtectedTool, BaseState
|
|
3
|
-
|
|
4
|
-
def get_tool_definition(state: BaseState, tools: List[ProtectedTool]) -> Optional[dict]:
|
|
5
|
-
message = state["messages"][-1]
|
|
6
|
-
|
|
7
|
-
if not hasattr(message, "tool_calls") or not message.tool_calls:
|
|
8
|
-
return None
|
|
9
|
-
|
|
10
|
-
tool_calls = message.tool_calls
|
|
11
|
-
tool = tool_calls[-1]
|
|
12
|
-
metadata = next((t for t in tools if t.tool_name == tool["name"]), None)
|
|
13
|
-
|
|
14
|
-
if not metadata:
|
|
15
|
-
return None
|
|
16
|
-
|
|
17
|
-
return {"metadata": metadata, "tool": tool, "message": message}
|
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from typing import Awaitable, Callable, Optional, TypedDict, Union
|
|
3
|
-
|
|
4
|
-
from auth0_ai.authorizers.ciba_authorizer import (
|
|
5
|
-
AuthorizeResponse,
|
|
6
|
-
CIBAAuthorizer,
|
|
7
|
-
CibaAuthorizerCheckResponse,
|
|
8
|
-
)
|
|
9
|
-
from auth0_ai.credentials import Credentials
|
|
10
|
-
from auth0_ai.token_response import TokenResponse
|
|
11
|
-
from langgraph.graph import END, START, StateGraph
|
|
12
|
-
from langgraph_sdk import get_client
|
|
13
|
-
from langgraph_sdk.schema import Command
|
|
14
|
-
|
|
15
|
-
from auth0_ai_langchain.ciba.types import Auth0Graphs
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class State(TypedDict):
|
|
19
|
-
ciba_response: AuthorizeResponse
|
|
20
|
-
on_resume_invoke: str
|
|
21
|
-
thread_id: str
|
|
22
|
-
user_id: str
|
|
23
|
-
|
|
24
|
-
# Internal
|
|
25
|
-
task_id: str
|
|
26
|
-
tool_id: str
|
|
27
|
-
status: CibaAuthorizerCheckResponse
|
|
28
|
-
token_response: Optional[TokenResponse]
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def ciba_poller_graph(on_stop_scheduler: Union[str, Callable[[State], Awaitable[None]]]):
|
|
32
|
-
"""
|
|
33
|
-
A LangGraph graph to monitor the status of a CIBA transaction.
|
|
34
|
-
|
|
35
|
-
Attributes:
|
|
36
|
-
on_stop_scheduler (Union[str, Callable[[State], Awaitable[None]]]): A graph name to redirect the flow, or a function to execute when the CIBA transaction expires.
|
|
37
|
-
"""
|
|
38
|
-
async def check_status(state: State):
|
|
39
|
-
try:
|
|
40
|
-
res = await CIBAAuthorizer.check(state["ciba_response"]["auth_req_id"])
|
|
41
|
-
state["token_response"] = res.get("token")
|
|
42
|
-
state["status"] = res.get("status")
|
|
43
|
-
except Exception as e:
|
|
44
|
-
print(f"Error in check_status: {e}")
|
|
45
|
-
return state
|
|
46
|
-
|
|
47
|
-
async def stop_scheduler(state: State):
|
|
48
|
-
try:
|
|
49
|
-
if isinstance(on_stop_scheduler, str):
|
|
50
|
-
langgraph = get_client(url=os.getenv(
|
|
51
|
-
"LANGGRAPH_API_URL", "http://localhost:54367"))
|
|
52
|
-
await langgraph.crons.create_for_thread(state.thread_id, Auth0Graphs.CIBA_POLLER.value)
|
|
53
|
-
elif callable(on_stop_scheduler):
|
|
54
|
-
await on_stop_scheduler(state)
|
|
55
|
-
except Exception as e:
|
|
56
|
-
print(f"Error in stop_scheduler: {e}")
|
|
57
|
-
return state
|
|
58
|
-
|
|
59
|
-
async def resume_agent(state: State):
|
|
60
|
-
langgraph = get_client(url=os.getenv(
|
|
61
|
-
"LANGGRAPH_API_URL", "http://localhost:54367"))
|
|
62
|
-
_credentials: Credentials = None
|
|
63
|
-
|
|
64
|
-
try:
|
|
65
|
-
if state["status"] == CibaAuthorizerCheckResponse.APPROVED:
|
|
66
|
-
_credentials = {
|
|
67
|
-
"access_token": {
|
|
68
|
-
"type": state["token_response"].get("token_type", "Bearer"),
|
|
69
|
-
"value": state["token_response"].get("access_token"),
|
|
70
|
-
}
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
await langgraph.runs.wait(
|
|
74
|
-
state["thread_id"],
|
|
75
|
-
state["on_resume_invoke"],
|
|
76
|
-
config={
|
|
77
|
-
# this is only for this run / thread_id
|
|
78
|
-
"configurable": {"_credentials": _credentials}
|
|
79
|
-
},
|
|
80
|
-
command=Command(resume={"status": state["status"]})
|
|
81
|
-
)
|
|
82
|
-
except Exception as e:
|
|
83
|
-
print(f"Error in resume_agent: {e}")
|
|
84
|
-
|
|
85
|
-
return state
|
|
86
|
-
|
|
87
|
-
async def should_continue(state: State):
|
|
88
|
-
status = state.get("status")
|
|
89
|
-
if status == CibaAuthorizerCheckResponse.PENDING:
|
|
90
|
-
return END
|
|
91
|
-
elif status == CibaAuthorizerCheckResponse.EXPIRED:
|
|
92
|
-
return "stop_scheduler"
|
|
93
|
-
elif status in [CibaAuthorizerCheckResponse.APPROVED, CibaAuthorizerCheckResponse.REJECTED]:
|
|
94
|
-
return "resume_agent"
|
|
95
|
-
return END
|
|
96
|
-
|
|
97
|
-
state_graph = StateGraph(State)
|
|
98
|
-
state_graph.add_node("check_status", check_status)
|
|
99
|
-
state_graph.add_node("stop_scheduler", stop_scheduler)
|
|
100
|
-
state_graph.add_node("resume_agent", resume_agent)
|
|
101
|
-
state_graph.add_edge(START, "check_status")
|
|
102
|
-
state_graph.add_edge("resume_agent", "stop_scheduler")
|
|
103
|
-
state_graph.add_conditional_edges("check_status", should_continue)
|
|
104
|
-
|
|
105
|
-
return state_graph
|
auth0_ai_langchain/ciba/types.py
DELETED
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
auth0_ai_langchain/FGARetriever.py,sha256=6nQXRkbDLHZt9zYZJsS5iQljrogQVLW0aVwDIf6Mpac,6002
|
|
2
|
-
auth0_ai_langchain/__init__.py,sha256=I331Kz-q97ZU7TfXaOR5UBbJamGEJ15twbf2HP1iCHs,67
|
|
3
|
-
auth0_ai_langchain/auth0_ai.py,sha256=8NUV_80SxR8qQt_3RQGf0Oga178kChuROHuhz7rfOyU,1919
|
|
4
|
-
auth0_ai_langchain/ciba/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
-
auth0_ai_langchain/ciba/ciba_graph/ciba_graph.py,sha256=Wi7qXSMzvcfqdO8WsUJejmQzOVM469TFJCkH7eRlaR8,4115
|
|
6
|
-
auth0_ai_langchain/ciba/ciba_graph/initialize_ciba.py,sha256=a41KedBzxfLqG2AhvkFnemcEqWwQWVh1r-Oro-qJX-M,3752
|
|
7
|
-
auth0_ai_langchain/ciba/ciba_graph/initialize_hitl.py,sha256=CR3jMolZYYOBHx1AXb6yERBaZThMg677qGGo_vRy6I8,1901
|
|
8
|
-
auth0_ai_langchain/ciba/ciba_graph/types.py,sha256=NZS99vPOgJRc2O7mO5MWj6nAa4RSG1R5oSvzYhkz0RA,4234
|
|
9
|
-
auth0_ai_langchain/ciba/ciba_graph/utils.py,sha256=ZPAh0Gs7Hj59_xngg8M7yx1v52dTn2pNpMNRpFKCSII,560
|
|
10
|
-
auth0_ai_langchain/ciba/ciba_poller_graph.py,sha256=rjlxsTheJrN2J6E5BCE9aVyDO8V7UFhdKyX4KT_hB8k,3828
|
|
11
|
-
auth0_ai_langchain/ciba/types.py,sha256=gybqYEprklZwcMBgaWFooBsJ1GcNUK8ZWRvAX5PZWdE,177
|
|
12
|
-
auth0_ai_langchain/federated_connections/__init__.py,sha256=OJCWTnxYuaWjn8FyFGjCAF7m5Y4Eigkzx7a59atFFFg,356
|
|
13
|
-
auth0_ai_langchain/federated_connections/federated_connection_authorizer.py,sha256=ZLF4p7fPTrODOeHWIShfBTsReSy27u73rIp0L_Umhjg,2218
|
|
14
|
-
auth0_ai_langchain/fga/fga_authorizer.py,sha256=uDaGDSXaxQd1X-2w2zTvnfizMB-DtQ-1G6SIaDNBrho,137
|
|
15
|
-
auth0_ai_langchain/utils/interrupt.py,sha256=JoYJkigDEAPRHZtjo6gw6k3439E4i1O7F4_0ExkL_RE,405
|
|
16
|
-
auth0_ai_langchain-0.1.2.dist-info/LICENSE,sha256=Lu_2YH0oK8b_VVisAhNQ2WIdtwY8pSU2PLbll-y6Cj8,9792
|
|
17
|
-
auth0_ai_langchain-0.1.2.dist-info/METADATA,sha256=synK995x8fRtUFH2yR0oyh5vwGzZTEdC2zUJygtzAwQ,7786
|
|
18
|
-
auth0_ai_langchain-0.1.2.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
|
19
|
-
auth0_ai_langchain-0.1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|