auth0-ai-langchain 0.2.0__py3-none-any.whl → 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auth0-ai-langchain might be problematic. Click here for more details.

@@ -32,7 +32,7 @@ class FGARetriever(BaseRetriever):
32
32
  Args:
33
33
  retriever (BaseRetriever): The retriever used to fetch documents.
34
34
  build_query (Callable[[Document], ClientBatchCheckItem]): Function to convert documents into FGA queries.
35
- fga_configuration (Optional[ClientConfiguration]): Configuration for the OpenFGA client. If not provided, defaults to environment variables.
35
+ fga_configuration (ClientConfiguration, optional): Configuration for the OpenFGA client. If not provided, defaults to environment variables.
36
36
  """
37
37
  super().__init__()
38
38
  self._retriever = retriever
@@ -95,7 +95,7 @@ class FGARetriever(BaseRetriever):
95
95
 
96
96
  Args:
97
97
  query (str): The query for retrieving documents.
98
- run_manager (Optional[object]): Optional manager for tracking runs.
98
+ run_manager (object, optional): Optional manager for tracking runs.
99
99
 
100
100
  Returns:
101
101
  List[Document]: Filtered and relevant documents.
@@ -148,7 +148,7 @@ class FGARetriever(BaseRetriever):
148
148
 
149
149
  Args:
150
150
  query (str): The query for retrieving documents.
151
- run_manager (Optional[object]): Optional manager for tracking runs.
151
+ run_manager (object, optional): Optional manager for tracking runs.
152
152
 
153
153
  Returns:
154
154
  List[Document]: Filtered and relevant documents.
@@ -1,43 +1,112 @@
1
1
  from typing import Callable, Optional
2
- from langchain_core.runnables.config import RunnableConfig
3
2
  from langchain_core.tools import BaseTool
4
- from auth0_ai.credentials import Credential
5
- from auth0_ai.authorizers.types import AuthorizerParams
6
- from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerParams
7
- from .federated_connections.federated_connection_authorizer import FederatedConnectionAuthorizer
8
- from .ciba.ciba_graph.ciba_graph import CIBAGraph
9
- from .ciba.ciba_graph.types import CIBAGraphOptions
10
-
11
- def get_access_token(config: RunnableConfig) -> Credential:
12
- """
13
- Fetch the access token obtained during the CIBA flow.
3
+ from auth0_ai.authorizers.ciba import CIBAAuthorizerParams
4
+ from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerParams
5
+ from auth0_ai.authorizers.types import Auth0ClientParams
6
+ from auth0_ai_langchain.ciba.ciba_authorizer import CIBAAuthorizer
7
+ from auth0_ai_langchain.federated_connections.federated_connection_authorizer import FederatedConnectionAuthorizer
8
+
14
9
 
15
- Attributes:
16
- config(RunnableConfig): LangGraph runnable configuration instance.
10
+ class Auth0AI:
11
+ """Provides decorators to secure LangChain tools using Auth0 authorization flows.
17
12
  """
18
- return config.get("configurable", {}).get("_credentials", {}).get("access_token")
19
13
 
20
- class Auth0AI():
21
- def __init__(self, config: Optional[AuthorizerParams] = None):
22
- self._graph: Optional[CIBAGraph] = None
23
- self.config = config
14
+ def __init__(self, auth0: Optional[Auth0ClientParams] = None):
15
+ """Initializes the Auth0AI instance.
24
16
 
25
- def with_async_user_confirmation(self, **options: CIBAGraphOptions) -> CIBAGraph:
17
+ Args:
18
+ auth0 (Optional[Auth0ClientParams]): Parameters for the Auth0 client.
19
+ If not provided, values will be automatically read from environment
20
+ variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`.
26
21
  """
27
- Initializes and registers a state graph for conditional trade operations using CIBA.
22
+ self.auth0 = auth0
28
23
 
29
- Attributes:
30
- options (Optional[CIBAGraphOptions]): The base CIBA options.
31
- """
32
- self._graph = CIBAGraph(CIBAGraphOptions(**options), self.config)
33
- return self._graph
34
-
35
- def with_federated_connection(self, **options: FederatedConnectionAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
24
+ def with_async_user_confirmation(self, **params: CIBAAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
25
+ """Protects a tool with the CIBA (Client-Initiated Backchannel Authentication) flow.
26
+
27
+ Requires user confirmation via a second device (e.g., phone)
28
+ before allowing the tool to execute.
29
+
30
+ Args:
31
+ **params: Parameters defined in `CIBAAuthorizerParams`.
32
+
33
+ Returns:
34
+ Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.
35
+
36
+ Example:
37
+ ```python
38
+ import os
39
+ from auth0_ai_langchain.auth0_ai import Auth0AI
40
+ from auth0_ai_langchain.ciba import get_ciba_credentials
41
+ from langchain_core.runnables import ensure_config
42
+ from langchain_core.tools import StructuredTool
43
+
44
+ auth0_ai = Auth0AI()
45
+
46
+ with_async_user_confirmation = auth0_ai.with_async_user_confirmation(
47
+ scope="stock:trade",
48
+ audience=os.getenv("AUDIENCE"),
49
+ binding_message=lambda ticker, qty: f"Authorize the purchase of {qty} {ticker}",
50
+ user_id=lambda *_, **__: ensure_config().get("configurable", {}).get("user_id")
51
+ )
52
+
53
+ def tool_function(ticker: str, qty: int) -> str:
54
+ credentials = get_ciba_credentials()
55
+ headers = {
56
+ "Authorization": f"{credentials['token_type']} {credentials['access_token']}",
57
+ # ...
58
+ }
59
+ # Call API
60
+
61
+ trade_tool = with_async_user_confirmation(
62
+ StructuredTool(
63
+ name="trade_tool",
64
+ description="Use this function to trade a stock",
65
+ func=tool_function,
66
+ )
67
+ )
68
+ ```
36
69
  """
37
- Protects a tool execution with the Federated Connection authorizer.
70
+ authorizer = CIBAAuthorizer(CIBAAuthorizerParams(**params), self.auth0)
71
+ return authorizer.authorizer()
72
+
73
+ def with_federated_connection(self, **params: FederatedConnectionAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
74
+ """Enables a tool to obtain an access token from a federated identity provider (e.g., Google, Azure AD).
75
+
76
+ The token can then be used within the tool to call third-party APIs on behalf of the user.
77
+
78
+ Args:
79
+ **params: Parameters defined in `FederatedConnectionAuthorizerParams`.
80
+
81
+ Returns:
82
+ Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.
83
+
84
+ Example:
85
+ ```python
86
+ from auth0_ai_langchain.auth0_ai import Auth0AI
87
+ from auth0_ai_langchain.federated_connections import get_credentials_for_connection
88
+ from langchain_core.tools import StructuredTool
89
+ from datetime import datetime
90
+
91
+ auth0_ai = Auth0AI()
92
+
93
+ with_google_calendar_access = auth0_ai.with_federated_connection(
94
+ connection="google-oauth2",
95
+ scopes=["https://www.googleapis.com/auth/calendar.freebusy"]
96
+ )
97
+
98
+ def tool_function(date: datetime):
99
+ credentials = get_credentials_for_connection()
100
+ # Call Google API using credentials["access_token"]
38
101
 
39
- Attributes:
40
- options (FederatedConnectionAuthorizerParams): The Federated Connections authorizer options.
102
+ check_calendar_tool = with_google_calendar_access(
103
+ StructuredTool(
104
+ name="check_user_calendar",
105
+ description="Use this function to check if the user is available on a certain date and time",
106
+ func=tool_function,
107
+ )
108
+ )
109
+ ```
41
110
  """
42
- authorizer = FederatedConnectionAuthorizer(FederatedConnectionAuthorizerParams(**options), self.config)
111
+ authorizer = FederatedConnectionAuthorizer(FederatedConnectionAuthorizerParams(**params), self.auth0)
43
112
  return authorizer.authorizer()
@@ -0,0 +1,3 @@
1
+ from auth0_ai.authorizers.ciba.ciba_authorizer_base import get_ciba_credentials as get_ciba_credentials
2
+ from auth0_ai_langchain.ciba.ciba_authorizer import CIBAAuthorizer as CIBAAuthorizer
3
+ from auth0_ai_langchain.ciba.graph_resumer import GraphResumer as GraphResumer
@@ -0,0 +1,17 @@
1
+ from abc import ABC
2
+ from typing import Union
3
+ from auth0_ai.authorizers.ciba import CIBAAuthorizerBase
4
+ from auth0_ai.interrupts.ciba_interrupts import AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
5
+ from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
6
+ from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
7
+ from langchain_core.tools import BaseTool
8
+
9
+ class CIBAAuthorizer(CIBAAuthorizerBase, ABC):
10
+ def _handle_authorization_interrupts(self, err: Union[AuthorizationPendingInterrupt, AuthorizationPollingInterrupt]) -> None:
11
+ raise to_graph_interrupt(err)
12
+
13
+ def authorizer(self):
14
+ def wrap_tool(tool: BaseTool) -> BaseTool:
15
+ return tool_wrapper(tool, self.protect)
16
+
17
+ return wrap_tool
@@ -0,0 +1,154 @@
1
+ import asyncio
2
+ from threading import Event
3
+ from typing import Callable, Optional, Dict, Any, List, TypedDict
4
+ from auth0_ai.authorizers.ciba import CIBAAuthorizationRequest
5
+ from auth0_ai.interrupts.ciba_interrupts import CIBAInterrupt, AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
6
+ from auth0_ai_langchain.utils.interrupt import get_auth0_interrupts
7
+ from langgraph_sdk.client import LangGraphClient
8
+ from langgraph_sdk.schema import Thread, Interrupt
9
+
10
+ class WatchedThread(TypedDict):
11
+ thread_id: str
12
+ assistant_id: str
13
+ interruption_id: str
14
+ auth_request: CIBAAuthorizationRequest
15
+ config: Dict[str, Any]
16
+ last_run: float
17
+
18
+ class GraphResumerFilters(TypedDict):
19
+ graph_id: str
20
+
21
+ class GraphResumer:
22
+ def __init__(self, lang_graph: LangGraphClient, filters: Optional[GraphResumerFilters] = None):
23
+ self.lang_graph = lang_graph
24
+ self.filters = filters or {}
25
+ self.map: Dict[str, WatchedThread] = {}
26
+ self._stop_event = Event()
27
+ self._loop_task: Optional[asyncio.Task] = None
28
+
29
+ # Event callbacks
30
+ self._resume_callbacks: List[Callable[[WatchedThread], None]] = []
31
+ self._error_callbacks: List[Callable[[Exception], None]] = []
32
+
33
+ # Public API to register event callbacks
34
+ def on_resume(self, callback: Callable[[WatchedThread], None]) -> "GraphResumer":
35
+ self._resume_callbacks.append(callback)
36
+ return self
37
+
38
+ def on_error(self, callback: Callable[[Exception], None]) -> "GraphResumer":
39
+ self._error_callbacks.append(callback)
40
+ return self
41
+
42
+ def _emit_resume(self, thread: WatchedThread) -> None:
43
+ for callback in self._resume_callbacks:
44
+ callback(thread)
45
+
46
+ def _emit_error(self, error: Exception) -> None:
47
+ for callback in self._error_callbacks:
48
+ callback(error)
49
+
50
+ async def _get_all_interrupted_threads(self) -> List[Thread]:
51
+ interrupted_threads: List[Thread] = []
52
+ offset = 0
53
+
54
+ while True:
55
+ page = await self.lang_graph.threads.search(
56
+ status="interrupted",
57
+ limit=100,
58
+ offset=offset,
59
+ metadata={"graph_id": self.filters["graph_id"]} if "graph_id" in self.filters else None
60
+ )
61
+
62
+ if not page:
63
+ break
64
+
65
+ for t in page:
66
+ interrupt = self._get_first_interrupt(t)
67
+ if interrupt and CIBAInterrupt.is_interrupt(interrupt["value"]) and CIBAInterrupt.has_request_data(interrupt["value"]):
68
+ interrupted_threads.append(t)
69
+
70
+ offset += len(page)
71
+ if len(page) < 100:
72
+ break
73
+
74
+ return interrupted_threads
75
+
76
+ def _get_first_interrupt(self, thread: Thread) -> Optional[Interrupt]:
77
+ interrupts = thread["interrupts"]
78
+ if interrupts:
79
+ values = list(interrupts.values())
80
+ if values and values[0]:
81
+ return values[0][0]
82
+ return None
83
+
84
+ def _get_hash_map_id(self, thread: Thread) -> str:
85
+ return f"{thread['thread_id']}:{next(iter(thread['interrupts']))}"
86
+
87
+ async def _resume_thread(self, t: WatchedThread):
88
+ self._emit_resume(t)
89
+
90
+ await self.lang_graph.runs.wait(t["thread_id"], t["assistant_id"], config=t["config"])
91
+
92
+ t["last_run"] = asyncio.get_event_loop().time() * 1000
93
+
94
+ async def loop(self):
95
+ all_threads = await self._get_all_interrupted_threads()
96
+
97
+ # Remove old interrupted threads
98
+ active_keys = {self._get_hash_map_id(t) for t in all_threads}
99
+
100
+ for key in list(self.map.keys()):
101
+ if key not in active_keys:
102
+ del self.map[key]
103
+
104
+ # Add new interrupted threads
105
+ for thread in all_threads:
106
+ interrupt = next(
107
+ (i for i in get_auth0_interrupts(thread)
108
+ if AuthorizationPendingInterrupt.is_interrupt(i["value"])
109
+ or AuthorizationPollingInterrupt.is_interrupt(i["value"])),
110
+ None
111
+ )
112
+
113
+ if not interrupt or not interrupt["value"].get("_request"):
114
+ continue
115
+
116
+ key = self._get_hash_map_id(thread)
117
+ if key not in self.map:
118
+ self.map[key] = {
119
+ "thread_id": thread["thread_id"],
120
+ "assistant_id": thread["metadata"].get("graph_id"),
121
+ "config": getattr(thread, "config", {}),
122
+ "interruption_id": next(iter(thread["interrupts"])),
123
+ "auth_request": interrupt["value"]["_request"],
124
+ }
125
+
126
+ threads_to_resume = [
127
+ t for t in self.map.values()
128
+ if "last_run" not in t or (t["last_run"] + t["auth_request"]["interval"] * 1000 < asyncio.get_event_loop().time() * 1000)
129
+ ]
130
+
131
+ await asyncio.gather(*[
132
+ self._resume_thread(t) for t in threads_to_resume
133
+ ])
134
+
135
+ def start(self):
136
+ if self._loop_task and not self._loop_task.done():
137
+ return
138
+
139
+ self._stop_event.clear()
140
+
141
+ async def _run_loop():
142
+ while not self._stop_event.is_set():
143
+ try:
144
+ await self.loop()
145
+ except Exception as e:
146
+ self._emit_error(e)
147
+ await asyncio.sleep(5)
148
+
149
+ self._loop_task = asyncio.create_task(_run_loop())
150
+
151
+ def stop(self):
152
+ self._stop_event.set()
153
+ if self._loop_task:
154
+ self._loop_task.cancel()
@@ -1,4 +1,7 @@
1
- from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionError as FederatedConnectionError
2
- from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionInterrupt as FederatedConnectionInterrupt
3
- from auth0_ai.authorizers.federated_connection_authorizer import get_access_token_for_connection as get_access_token_for_connection
1
+ from auth0_ai.interrupts.federated_connection_interrupt import (
2
+ FederatedConnectionError as FederatedConnectionError,
3
+ FederatedConnectionInterrupt as FederatedConnectionInterrupt
4
+ )
5
+
6
+ from auth0_ai.authorizers.federated_connection_authorizer import get_credentials_for_connection as get_credentials_for_connection
4
7
  from .federated_connection_authorizer import FederatedConnectionAuthorizer as FederatedConnectionAuthorizer
@@ -1,52 +1,33 @@
1
1
  import copy
2
2
  from abc import ABC
3
3
  from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerBase, FederatedConnectionAuthorizerParams
4
- from auth0_ai.authorizers.types import AuthorizerParams
4
+ from auth0_ai.authorizers.types import Auth0ClientParams
5
5
  from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionInterrupt
6
- from langchain_core.tools import BaseTool, tool
6
+ from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
7
+ from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
8
+ from langchain_core.tools import BaseTool
7
9
  from langchain_core.runnables import ensure_config
8
- from ..utils.interrupt import to_graph_interrupt
9
10
 
10
- async def get_refresh_token(*_args, **_kwargs) -> str | None:
11
+ async def default_get_refresh_token(*_, **__) -> str | None:
11
12
  return ensure_config().get("configurable", {}).get("_credentials", {}).get("refresh_token")
12
13
 
13
14
  class FederatedConnectionAuthorizer(FederatedConnectionAuthorizerBase, ABC):
14
15
  def __init__(
15
16
  self,
16
- options: FederatedConnectionAuthorizerParams,
17
- config: AuthorizerParams = None,
17
+ params: FederatedConnectionAuthorizerParams,
18
+ auth0: Auth0ClientParams = None,
18
19
  ):
19
- if options.refresh_token.value is None:
20
- options = copy.copy(options)
21
- options.refresh_token.value = get_refresh_token
20
+ if params.refresh_token.value is None:
21
+ params = copy.copy(params)
22
+ params.refresh_token.value = default_get_refresh_token
22
23
 
23
- super().__init__(options, config)
24
+ super().__init__(params, auth0)
24
25
 
25
26
  def _handle_authorization_interrupts(self, err: FederatedConnectionInterrupt) -> None:
26
27
  raise to_graph_interrupt(err)
27
28
 
28
29
  def authorizer(self):
29
- def wrapped_tool(t: BaseTool) -> BaseTool:
30
- async def execute_fn(*_args, **kwargs):
31
- return await t.ainvoke(input=kwargs)
32
-
33
- tool_fn = self.protect(
34
- lambda *_args, **_kwargs: {
35
- "thread_id": ensure_config().get("configurable", {}).get("thread_id"),
36
- "checkpoint_ns": ensure_config().get("configurable", {}).get("checkpoint_ns"),
37
- "run_id": ensure_config().get("configurable", {}).get("run_id"),
38
- "tool_call_id": ensure_config().get("configurable", {}).get("tool_call_id"), # TODO: review this
39
- },
40
- execute_fn
41
- )
42
- tool_fn.__name__ = t.name
43
-
44
- return tool(
45
- tool_fn,
46
- description=t.description,
47
- return_direct=t.return_direct,
48
- args_schema=t.args_schema,
49
- response_format=t.response_format,
50
- )
30
+ def wrap_tool(tool: BaseTool) -> BaseTool:
31
+ return tool_wrapper(tool, self.protect)
51
32
 
52
- return wrapped_tool
33
+ return wrap_tool
@@ -0,0 +1,4 @@
1
+ from auth0_ai.authorizers.fga_authorizer import (
2
+ FGAAuthorizer as FGAAuthorizer,
3
+ FGAAuthorizerOptions as FGAAuthorizerOptions
4
+ )
@@ -1,6 +1,9 @@
1
+ from typing import List
1
2
  from auth0_ai.interrupts.auth0_interrupt import Auth0Interrupt
2
3
  from langgraph.errors import GraphInterrupt
3
4
  from langgraph.types import Interrupt
5
+ from langgraph_sdk.schema import Thread
6
+
4
7
 
5
8
  def to_graph_interrupt(interrupt: Auth0Interrupt) -> GraphInterrupt:
6
9
  return GraphInterrupt([
@@ -8,6 +11,20 @@ def to_graph_interrupt(interrupt: Auth0Interrupt) -> GraphInterrupt:
8
11
  value=interrupt.to_json(),
9
12
  when="during",
10
13
  resumable=True,
11
- ns=[f"auth0AI:{interrupt.__class__.__name__}:{interrupt.code}"]
14
+ ns=[f"auth0AI:{interrupt.name}:{interrupt.code}"]
12
15
  )
13
16
  ])
17
+
18
+
19
+ def get_auth0_interrupts(thread: Thread) -> List[Interrupt]:
20
+ result = []
21
+
22
+ if "interrupts" not in thread:
23
+ return result
24
+
25
+ for interrupt_list in thread["interrupts"].values():
26
+ for interrupt in interrupt_list:
27
+ if Auth0Interrupt.is_interrupt(interrupt["value"]):
28
+ result.append(interrupt)
29
+
30
+ return result
@@ -0,0 +1,34 @@
1
+ from typing import Callable
2
+ from typing_extensions import Annotated
3
+ from pydantic import create_model
4
+ from langchain_core.tools import BaseTool, tool as create_tool, InjectedToolCallId
5
+ from langchain_core.runnables import RunnableConfig
6
+
7
+ def tool_wrapper(tool: BaseTool, protect_fn: Callable) -> BaseTool:
8
+
9
+ # Workaround: extend existing args_schema to be able to get the tool_call_id value
10
+ args_schema = create_model(
11
+ tool.args_schema.__name__ + "Extended",
12
+ __base__=tool.args_schema,
13
+ **{"tool_call_id": (Annotated[str, InjectedToolCallId])}
14
+ )
15
+
16
+ @create_tool(
17
+ tool.name,
18
+ description=tool.description,
19
+ args_schema=args_schema
20
+ )
21
+ async def wrapped_tool(config: RunnableConfig, tool_call_id: Annotated[str, InjectedToolCallId], **input):
22
+ async def execute_fn(*_, **__):
23
+ return await tool.ainvoke(input, config)
24
+
25
+ return await protect_fn(
26
+ lambda *_, **__: {
27
+ "thread_id": config.get("configurable", {}).get("thread_id"),
28
+ "tool_call_id": tool_call_id,
29
+ "tool_name": tool.name,
30
+ },
31
+ execute_fn,
32
+ )(**input)
33
+
34
+ return wrapped_tool
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: auth0-ai-langchain
3
- Version: 0.2.0
3
+ Version: 1.0.0b1
4
4
  Summary: This package is an SDK for building secure AI-powered applications using Auth0, Okta FGA and LangChain.
5
5
  License: Apache-2.0
6
6
  Author: Auth0
@@ -11,12 +11,12 @@ Classifier: Programming Language :: Python :: 3
11
11
  Classifier: Programming Language :: Python :: 3.11
12
12
  Classifier: Programming Language :: Python :: 3.12
13
13
  Classifier: Programming Language :: Python :: 3.13
14
- Requires-Dist: auth0-ai (>=0.2.0,<0.3.0)
15
- Requires-Dist: langchain (>=0.3.20,<0.4.0)
16
- Requires-Dist: langchain-core (>=0.3.43,<0.4.0)
17
- Requires-Dist: langgraph (>=0.3.25,<0.4.0)
18
- Requires-Dist: langgraph-sdk (>=0.1.55,<0.2.0)
19
- Requires-Dist: openfga-sdk (>=0.9.0,<0.10.0)
14
+ Requires-Dist: auth0-ai (>=1.0.0b1,<2.0.0)
15
+ Requires-Dist: langchain (>=0.3.25,<0.4.0)
16
+ Requires-Dist: langchain-core (>=0.3.59,<0.4.0)
17
+ Requires-Dist: langgraph (>=0.4.3,<0.5.0)
18
+ Requires-Dist: langgraph-sdk (>=0.1.66,<0.2.0)
19
+ Requires-Dist: openfga-sdk (>=0.9.4,<0.10.0)
20
20
  Project-URL: Homepage, https://auth0.com
21
21
  Description-Content-Type: text/markdown
22
22
 
@@ -34,6 +34,52 @@ Description-Content-Type: text/markdown
34
34
  pip install auth0-ai-langchain
35
35
  ```
36
36
 
37
+ ## Async User Confirmation
38
+
39
+ `Auth0AI` uses CIBA (Client-Initiated Backchannel Authentication) to handle user confirmation asynchronously. This is useful when you need to confirm a user action before proceeding with a tool execution.
40
+
41
+ Full Example of [Async User Confirmation](https://github.com/auth0-lab/auth0-ai-python/tree/main/examples/async-user-confirmation/langchain-examples).
42
+
43
+ 1. Define a tool with the appropriate authorizer, specifying a function that resolves the user ID:
44
+
45
+ ```python
46
+ from auth0_ai_langchain.auth0_ai import Auth0AI
47
+ from auth0_ai_langchain.ciba import get_ciba_credentials
48
+ from langchain_core.runnables import ensure_config
49
+ from langchain_core.tools import StructuredTool
50
+
51
+ # If not provided, Auth0 settings will be read from env variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`
52
+ auth0_ai = Auth0AI()
53
+
54
+ with_async_user_confirmation = auth0_ai.with_async_user_confirmation(
55
+ scope="stock:trade",
56
+ audience=os.getenv("AUDIENCE"),
57
+ binding_message=lambda ticker, qty: f"Authorize the purchase of {qty} {ticker}",
58
+ user_id=lambda *_, **__: ensure_config().get("configurable", {}).get("user_id"),
59
+ # Optional:
60
+ # store=InMemoryStore()
61
+ )
62
+
63
+ def tool_function(ticker: str, qty: int) -> str:
64
+ credentials = get_ciba_credentials()
65
+ headers = {
66
+ "Authorization": f"{credentials['token_type']} {credentials['access_token']}",
67
+ # ...
68
+ }
69
+ # Call API
70
+
71
+ trade_tool = with_async_user_confirmation(
72
+ StructuredTool(
73
+ name="trade_tool",
74
+ description="Use this function to trade a stock",
75
+ func=tool_function,
76
+ # ...
77
+ )
78
+ )
79
+ ```
80
+
81
+ 2. Handle interrupts properly. For example, if the user is not enrolled in MFA, the tool will raise an interrupt. See the [Handling Interrupts](#handling-interrupts) section.
82
+
37
83
  ## Authorization for Tools
38
84
 
39
85
  The `FGAAuthorizer` can leverage Okta FGA to authorize tool executions. The `FGAAuthorizer.create` function can be used to create an authorizer that checks permissions before executing the tool.
@@ -43,19 +89,12 @@ Full example of [Authorization for Tools](https://github.com/auth0-lab/auth0-ai-
43
89
  1. Create an instance of FGA Authorizer:
44
90
 
45
91
  ```python
46
- from auth0_ai_langchain.fga.fga_authorizer import FGAAuthorizer, FGAAuthorizerOptions
92
+ from auth0_ai_langchain.fga import FGAAuthorizer
47
93
 
94
+ # If not provided, FGA settings will be read from env variables: `FGA_STORE_ID`, `FGA_CLIENT_ID`, `FGA_CLIENT_SECRET`, etc.
48
95
  fga = FGAAuthorizer.create()
49
96
  ```
50
97
 
51
- **Note**: Here, you can configure and specify your FGA credentials. By `default`, they are read from environment variables:
52
-
53
- ```sh
54
- FGA_STORE_ID="<fga-store-id>"
55
- FGA_CLIENT_ID="<fga-client-id>"
56
- FGA_CLIENT_SECRET="<fga-client-secret>"
57
- ```
58
-
59
98
  2. Define the FGA query (`build_query`) and, optionally, the `on_unauthorized` handler:
60
99
 
61
100
  ```python
@@ -73,10 +112,10 @@ async def build_fga_query(tool_input):
73
112
  def on_unauthorized(tool_input):
74
113
  return f"The user is not allowed to buy {tool_input['qty']} shares of {tool_input['ticker']}."
75
114
 
76
- use_fga = fga(FGAAuthorizerOptions(
115
+ use_fga = fga(
77
116
  build_query=build_fga_query,
78
117
  on_unauthorized=on_unauthorized,
79
- ))
118
+ )
80
119
  ```
81
120
 
82
121
  **Note**: The parameters given to the `build_query` and `on_unauthorized` functions are the same as those provided to the tool function.
@@ -102,7 +141,7 @@ buy_tool = StructuredTool(
102
141
 
103
142
  ## Calling APIs On User's Behalf
104
143
 
105
- The `Auth0AI.with_federated_connection` function exchanges user's refresh token taken from the runnable configuration (`config.configurable._credentials.refresh_token`) for a Federated Connection API token.
144
+ The `Auth0AI.with_federated_connection` function exchanges the user's refresh token (taken by default from the runnable configuration, `config.configurable._credentials.refresh_token`) for a Federated Connection API token.
106
145
 
107
146
  Full Example of [Calling APIs On User's Behalf](https://github.com/auth0-lab/auth0-ai-python/tree/main/examples/calling-apis/langchain-examples).
108
147
 
@@ -110,19 +149,23 @@ Full Example of [Calling APIs On User's Behalf](https://github.com/auth0-lab/aut
110
149
 
111
150
  ```python
112
151
  from auth0_ai_langchain.auth0_ai import Auth0AI
113
- from auth0_ai_langchain.federated_connections import get_access_token_for_connection
152
+ from auth0_ai_langchain.federated_connections import get_credentials_for_connection
114
153
  from langchain_core.tools import StructuredTool
115
154
 
155
+ # If not provided, Auth0 settings will be read from env variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`
116
156
  auth0_ai = Auth0AI()
117
157
 
118
158
  with_google_calendar_access = auth0_ai.with_federated_connection(
119
159
  connection="google-oauth2",
120
- scopes=["https://www.googleapis.com/auth/calendar.freebusy"]
160
+ scopes=["https://www.googleapis.com/auth/calendar.freebusy"],
161
+ # Optional:
162
+ # refresh_token=lambda *_, **__: ensure_config().get("configurable", {}).get("_credentials", {}).get("refresh_token"),
163
+ # store=InMemoryStore(),
121
164
  )
122
165
 
123
166
  def tool_function(date: datetime):
124
- access_token = get_access_token_for_connection()
125
- # Call Google API
167
+ credentials = get_credentials_for_connection()
168
+ # Call Google API using credentials["access_token"]
126
169
 
127
170
  check_calendar_tool = with_google_calendar_access(
128
171
  StructuredTool(
@@ -154,7 +197,7 @@ workflow = (
154
197
  )
155
198
  ```
156
199
 
157
- 3. Handle interruptions properly. If the tool does not have access to user's Google Calendar, it will throw an interruption.
200
+ 3. Handle interrupts properly. For example, if the tool does not have access to the user's Google Calendar, it will raise an interrupt. See the [Handling Interrupts](#handling-interrupts) section.
158
201
 
159
202
  ## RAG with FGA
160
203
 
@@ -184,7 +227,8 @@ vector_store = VectorStoreIndex.from_documents(documents)
184
227
  # Create a retriever:
185
228
  base_retriever = vector_store.as_retriever()
186
229
 
187
- # Create the FGA retriever wrapper:
230
+ # Create the FGA retriever wrapper.
231
+ # If not provided, FGA settings will be read from env variables: `FGA_STORE_ID`, `FGA_CLIENT_ID`, `FGA_CLIENT_SECRET`, etc.
188
232
  retriever = FGARetriever(
189
233
  base_retriever,
190
234
  build_query=lambda node: ClientCheckRequest(
@@ -206,6 +250,63 @@ response = query_engine.query("What is the forecast for ZEKO?")
206
250
  print(response)
207
251
  ```
208
252
 
253
+ ## Handling Interrupts
254
+
255
+ `Auth0AI` uses interrupts extensively and never blocks the graph. Whenever an authorizer requires user interaction, the graph raises a `GraphInterrupt` exception carrying the data the client needs to resume the flow.
256
+
257
+ It is important to disable error handling in your tools node as follows:
258
+
259
+ ```python
260
+ .add_node(
261
+ "tools",
262
+ ToolNode(
263
+ [
264
+ # your authorizer-wrapped tools
265
+ ],
266
+ # Error handler should be disabled in order to trigger interruptions from within tools.
267
+ handle_tool_errors=False
268
+ )
269
+ )
270
+ ```
271
+
272
+ On the client side, retrieve the Auth0 interrupts from the thread:
273
+
274
+ ```python
275
+ from auth0_ai_langchain.utils.interrupt import get_auth0_interrupts
276
+
277
+ # Get the langgraph thread:
278
+ thread = await client.threads.get(thread_id)
279
+
280
+ # Filter the auth0 interrupts:
281
+ auth0_interrupts = get_auth0_interrupts(thread)
282
+ ```
283
+
284
+ Then you can resume the thread by doing this:
285
+
286
+ ```python
287
+ await client.runs.wait(thread_id, assistant_id)
288
+ ```
289
+
290
+ For the specific case of **CIBA (Client-Initiated Backchannel Authentication)**, you can attach a `GraphResumer` instance that watches for interrupted threads in the `"Authorization Pending"` state and attempts to resume them automatically, respecting Auth0's polling interval.
291
+
292
+ ```python
293
+ import os
294
+ from auth0_ai_langchain.ciba import GraphResumer
295
+ from langgraph_sdk import get_client
296
+
297
+ resumer = GraphResumer(
298
+ lang_graph=get_client(url=os.getenv("LANGGRAPH_API_URL")),
299
+ # optionally, you can filter by a specific graph:
300
+ filters={"graph_id": "conditional-trade"},
301
+ )
302
+
303
+ resumer \
304
+ .on_resume(lambda thread: print(f"Attempting to resume thread {thread['thread_id']} from interruption {thread['interruption_id']}")) \
305
+ .on_error(lambda err: print(f"Error in GraphResumer: {str(err)}"))
306
+
307
+ resumer.start()
308
+ ```
309
+
209
310
  ---
210
311
 
211
312
  <p align="center">
@@ -0,0 +1,15 @@
1
+ auth0_ai_langchain/FGARetriever.py,sha256=SQwxo2aDtQhwQtYmszoKw3BH-U5QVnvPAgVw9EDzKVM,6002
2
+ auth0_ai_langchain/__init__.py,sha256=I331Kz-q97ZU7TfXaOR5UBbJamGEJ15twbf2HP1iCHs,67
3
+ auth0_ai_langchain/auth0_ai.py,sha256=J3fxYNZf0KMK2w085dCdGfCRyafQGWPAI19edcYpQi8,4732
4
+ auth0_ai_langchain/ciba/__init__.py,sha256=X62HZB20XdhsgcaKld6rLm2BOSuiO5uU5v7ePQz27Mk,268
5
+ auth0_ai_langchain/ciba/ciba_authorizer.py,sha256=GRAB3NBnmoxAECrRjPNdA9N9uQ4pCEzP6dF8RUwlysM,766
6
+ auth0_ai_langchain/ciba/graph_resumer.py,sha256=EpdzzB_NccdggKA3x__Q3Yziejo7AJjR4aJ57TZmYPA,5474
7
+ auth0_ai_langchain/federated_connections/__init__.py,sha256=kGpPN9ntsyvE-2m_lcdCVvPevadyILCk3NiAm4TN0QA,429
8
+ auth0_ai_langchain/federated_connections/federated_connection_authorizer.py,sha256=o25oRGiTo9y5mpjDNEWWaFVAIFbhwaxC0pcRId-4oYE,1405
9
+ auth0_ai_langchain/fga/__init__.py,sha256=rgqTD4Gvz28jNdqhxTG5udbgyeUMsyvRj83fHBJdt4s,137
10
+ auth0_ai_langchain/utils/interrupt.py,sha256=DZ1b9OAkg3SQru9mSaQGBC6UY0ODz7QSskS9RlVyEGw,860
11
+ auth0_ai_langchain/utils/tool_wrapper.py,sha256=dHjcqykT2aohdFOm0mLZ9U6bXB6NHjfABb3aXef5174,1210
12
+ auth0_ai_langchain-1.0.0b1.dist-info/LICENSE,sha256=Lu_2YH0oK8b_VVisAhNQ2WIdtwY8pSU2PLbll-y6Cj8,9792
13
+ auth0_ai_langchain-1.0.0b1.dist-info/METADATA,sha256=lyc_GI9ymhgIrQh2wM2fv-lYF_uWT9VXFanst8HowEs,11790
14
+ auth0_ai_langchain-1.0.0b1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
15
+ auth0_ai_langchain-1.0.0b1.dist-info/RECORD,,
@@ -1,109 +0,0 @@
1
- from typing import Awaitable, Hashable, List, Optional, Callable, Any, Union
2
- from langchain_core.tools import StructuredTool
3
- from langchain_core.tools.base import BaseTool
4
- from langgraph.graph import StateGraph, END, START
5
- from langchain_core.runnables import Runnable
6
- from auth0_ai.authorizers.types import AuthorizerParams
7
- from ..types import Auth0Nodes
8
- from .initialize_ciba import initialize_ciba
9
- from .initialize_hitl import initialize_hitl
10
- from .types import CIBAGraphOptions, CIBAOptions, ProtectedTool, BaseState
11
-
12
- class CIBAGraph():
13
- def __init__(
14
- self,
15
- options: Optional[CIBAGraphOptions] = None,
16
- authorizer_params: Optional[AuthorizerParams] = None,
17
- ):
18
- self.options = options
19
- self.authorizer_params = authorizer_params
20
- self.tools: List[ProtectedTool] = []
21
- self.graph: Optional[StateGraph] = None
22
-
23
- def get_tools(self) -> List[ProtectedTool]:
24
- return self.tools
25
-
26
- def get_graph(self) -> Optional[StateGraph]:
27
- return self.graph
28
-
29
- def get_options(self) -> Optional[CIBAGraphOptions]:
30
- return self.options
31
-
32
- def get_authorizer_params(self) -> Optional[AuthorizerParams]:
33
- return self.authorizer_params
34
-
35
- def register_nodes(
36
- self,
37
- graph: StateGraph,
38
- ) -> StateGraph:
39
- self.graph = graph
40
-
41
- # Add CIBA HITL and CIBA nodes
42
- self.graph.add_node(Auth0Nodes.AUTH0_CIBA_HITL.value, initialize_hitl(self))
43
- self.graph.add_node(Auth0Nodes.AUTH0_CIBA.value, initialize_ciba(self))
44
- self.graph.add_conditional_edges(
45
- Auth0Nodes.AUTH0_CIBA.value,
46
- lambda state: END if getattr(state, "auth0", {}).get("error") else Auth0Nodes.AUTH0_CIBA_HITL.value,
47
- )
48
-
49
- return graph
50
-
51
- def protect_tool(
52
- self,
53
- tool: Union[BaseTool, Callable],
54
- options: CIBAOptions,
55
- ) -> StructuredTool:
56
- """
57
- Authorize Options to start CIBA flow.
58
-
59
- Attributes:
60
- tool (Union[BaseTool, Callable]): The tool to be protected.
61
- options (CIBAOptions): The CIBA options.
62
- """
63
-
64
- # Merge default options with tool-specific options
65
- merged_options = {**self.options, **options.__dict__} if isinstance(self.options, dict) else {**vars(self.options), **vars(options)}
66
-
67
- if merged_options["on_approve_go_to"] is None:
68
- raise ValueError(f"[{tool.name}] on_approve_go_to is required")
69
-
70
- if merged_options["on_reject_go_to"] is None:
71
- raise ValueError(f"[{tool.name}] on_reject_go_to is required")
72
-
73
- self.tools.append(ProtectedTool(tool_name=tool.name, options=merged_options))
74
-
75
- return tool
76
-
77
- def with_auth(self, path: Union[
78
- Callable[..., Union[Hashable, list[Hashable]]],
79
- Callable[..., Awaitable[Union[Hashable, list[Hashable]]]],
80
- Runnable[Any, Union[Hashable, list[Hashable]]],
81
- ]):
82
- """
83
- A wrapper for the callable that determines the next node or nodes using a protected tool.
84
-
85
- Attributes:
86
- path (Union[Callable[..., Union[Hashable, list[Hashable]]], Callable[..., Awaitable[Union[Hashable, list[Hashable]]]], Runnable[Any, Union[Hashable, list[Hashable]]]])): The callable that determines the next node or nodes using a protected tool.
87
- """
88
- def wrapper(*args):
89
- if not callable(path):
90
- return START
91
-
92
- state: BaseState = args[0]
93
- messages = state.get("messages")
94
- last_message = messages[-1] if messages else None
95
-
96
- # Call default path if there are no tool calls
97
- if not last_message or not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
98
- return path(*args)
99
-
100
- tool_name = last_message.tool_calls[0]["name"]
101
- tool = next((t for t in self.tools if t.tool_name == tool_name), None)
102
-
103
- if tool:
104
- return Auth0Nodes.AUTH0_CIBA.value
105
-
106
- # Call default path if tool is not protected
107
- return path(*args)
108
-
109
- return wrapper
@@ -1,91 +0,0 @@
1
- import os
2
- from langgraph.types import Command
3
- from langgraph_sdk import get_client
4
- from langchain_core.runnables.config import RunnableConfig
5
- from auth0_ai.authorizers.ciba_authorizer import CIBAAuthorizer
6
- from ..types import Auth0Graphs, Auth0Nodes
7
- from .types import ICIBAGraph, BaseState
8
- from .utils import get_tool_definition
9
-
10
- def initialize_ciba(ciba_graph: ICIBAGraph):
11
- async def handler(state: BaseState, config: RunnableConfig):
12
- try:
13
- ciba_params = ciba_graph.get_options()
14
- tools = ciba_graph.get_tools()
15
- tool_definition = get_tool_definition(state, tools)
16
-
17
- if not tool_definition:
18
- return Command(resume=True)
19
-
20
- graph = ciba_graph.get_graph()
21
- metadata, tool = tool_definition["metadata"], tool_definition["tool"]
22
- ciba_options = metadata.options
23
-
24
- langgraph = get_client(url=os.getenv("LANGGRAPH_API_URL", "http://localhost:54367"))
25
-
26
- # Check if CIBA Poller Graph exists
27
- search_result = await langgraph.assistants.search(graph_id=Auth0Graphs.CIBA_POLLER.value)
28
- if not search_result:
29
- raise ValueError(
30
- f"[{Auth0Nodes.AUTH0_CIBA}] \"{Auth0Graphs.CIBA_POLLER}\" does not exist. Make sure to register the graph in your \"langgraph.json\"."
31
- )
32
-
33
- if ciba_options["on_approve_go_to"] not in graph.nodes:
34
- raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"{ciba_options["on_approve_go_to"]}\" is not a valid node.")
35
-
36
- if ciba_options["on_reject_go_to"] not in graph.nodes:
37
- raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"{ciba_options["on_reject_go_to"]}\" is not a valid node.")
38
-
39
- scheduler = ciba_params.config["scheduler"]
40
- on_resume_invoke = ciba_params.config["on_resume_invoke"]
41
- audience = ciba_params.audience
42
-
43
- if not scheduler:
44
- raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"scheduler\" must be a \"function\" or a \"string\".")
45
-
46
- if not on_resume_invoke:
47
- raise ValueError(f"[{Auth0Nodes.AUTH0_CIBA}] \"on_resume_invoke\" must be defined.")
48
-
49
- user_id = config.get("configurable", {}).get("user_id")
50
- thread_id = config.get("metadata", {}).get("thread_id")
51
-
52
- ciba_response = await CIBAAuthorizer.start(
53
- {
54
- "user_id": user_id,
55
- "scope": ciba_options["scope"] or "openid",
56
- "audience": audience,
57
- "binding_message": ciba_options["binding_message"],
58
- },
59
- ciba_graph.get_authorizer_params(),
60
- tool["args"],
61
- )
62
-
63
- scheduler_params = {
64
- "tool_id": tool["id"],
65
- "user_id": user_id,
66
- "ciba_graph_id": Auth0Graphs.CIBA_POLLER.value,
67
- "thread_id": thread_id,
68
- "ciba_response": ciba_response,
69
- "on_resume_invoke": on_resume_invoke,
70
- }
71
-
72
- if callable(scheduler):
73
- # Use Custom Scheduler
74
- await scheduler(scheduler_params)
75
- elif isinstance(scheduler, str):
76
- # Use Langgraph SDK to schedule the task
77
- await langgraph.crons.create_for_thread(
78
- thread_id,
79
- scheduler_params["ciba_graph_id"],
80
- schedule="*/1 * * * *", # Default to every minute
81
- input=scheduler_params,
82
- )
83
-
84
- print("CIBA Task Scheduled")
85
- except Exception as e:
86
- print(e)
87
- state["auth0"] = {"error": str(e)}
88
-
89
- return state
90
-
91
- return handler
@@ -1,50 +0,0 @@
1
- from typing import Awaitable, Callable
2
- from langchain_core.messages import ToolMessage, AIMessage, ToolCall
3
- from langgraph.types import interrupt, Command
4
- from .types import ICIBAGraph, BaseState
5
- from .utils import get_tool_definition
6
- from auth0_ai.authorizers.ciba_authorizer import CibaAuthorizerCheckResponse
7
-
8
- def initialize_hitl(ciba_graph: ICIBAGraph) -> Callable[[BaseState], Awaitable[Command]]:
9
- async def handler(state: BaseState) -> Command:
10
- tools = ciba_graph.get_tools()
11
- tool_definition = get_tool_definition(state, tools)
12
-
13
- # if no tool calls, resume
14
- if not tool_definition:
15
- return Command(resume=True)
16
-
17
- # wait for user approval
18
- human_review = interrupt("A push notification has been sent to your device.")
19
-
20
- metadata, tool, message = tool_definition["metadata"], tool_definition["tool"], tool_definition["message"]
21
-
22
- if human_review["status"] == CibaAuthorizerCheckResponse.APPROVED.value:
23
- updated_message = AIMessage(
24
- id=message.id,
25
- content="The user has approved the transaction",
26
- tool_calls=[
27
- ToolCall(
28
- name=tool["name"],
29
- args=tool["args"],
30
- id=tool["id"],
31
- )
32
- ],
33
- )
34
-
35
- return Command(
36
- goto=metadata.options["on_approve_go_to"],
37
- update={"messages": [updated_message]},
38
- )
39
- else:
40
- tool_message = ToolMessage(
41
- name=tool["name"],
42
- content="The user has rejected the transaction.",
43
- tool_call_id=tool["id"],
44
- )
45
- return Command(
46
- goto=metadata.options["on_reject_go_to"],
47
- update={"messages": [tool_message]},
48
- )
49
-
50
- return handler
@@ -1,115 +0,0 @@
1
- from typing import Optional, List, Callable, Union, Awaitable, TypedDict
2
- from abc import ABC, abstractmethod
3
- from langgraph.graph import StateGraph
4
- from langchain_core.messages import AIMessage, ToolMessage
5
- from auth0_ai.authorizers.types import AuthorizerParams
6
- from auth0_ai.authorizers.ciba_authorizer import AuthorizeResponse
7
-
8
- class Auth0State(TypedDict):
9
- error: str
10
-
11
- class BaseState(TypedDict):
12
- task_id: str
13
- messages: List[Union[AIMessage, ToolMessage]]
14
- auth0: Optional[Auth0State] = None
15
-
16
- class SchedulerParams:
17
- def __init__(
18
- self,
19
- user_id: str,
20
- thread_id: str,
21
- ciba_graph_id: str,
22
- ciba_response: AuthorizeResponse,
23
- tool_id: Optional[str] = None,
24
- on_resume_invoke: str = "",
25
- ):
26
- self.user_id = user_id
27
- self.thread_id = thread_id
28
- self.tool_id = tool_id
29
- self.on_resume_invoke = on_resume_invoke
30
- self.ciba_graph_id = ciba_graph_id
31
- self.ciba_response = ciba_response
32
-
33
- class CIBAOptions():
34
- """
35
- The CIBA options.
36
-
37
- Attributes:
38
- binding_message (Union[str, Callable[..., Awaitable[str]]]): A human-readable string to display to the user, or a function that resolves it.
39
- scope (Optional[str]): Space-separated list of OIDC and custom API scopes.
40
- on_approve_go_to (Optional[str]): A node name to redirect the flow after user approval.
41
- on_reject_go_to (Optional[str]): A node name to redirect the flow after user rejection.
42
- audience (Optional[str]): Unique identifier of the audience for an issued token.
43
- request_expiry (Optional[int]): To configure a custom expiry time in seconds for CIBA request, pass a number between 1 and 300.
44
- """
45
- def __init__(
46
- self,
47
- binding_message: Union[str, Callable[..., Awaitable[str]]],
48
- scope: Optional[str] = None,
49
- on_approve_go_to: Optional[str] = None,
50
- on_reject_go_to: Optional[str] = None,
51
- audience: Optional[str] = None,
52
- request_expiry: Optional[int] = None,
53
- ):
54
- self.binding_message = binding_message
55
- self.scope = scope
56
- self.on_approve_go_to = on_approve_go_to
57
- self.on_reject_go_to = on_reject_go_to
58
- self.audience = audience
59
- self.request_expiry = request_expiry
60
-
61
- class ProtectedTool():
62
- def __init__(self, tool_name: str, options: CIBAOptions):
63
- self.tool_name = tool_name
64
- self.options = options
65
-
66
- class CIBAGraphOptionsConfig:
67
- def __init__(self, on_resume_invoke: str, scheduler: Union[str, Callable[[SchedulerParams], Awaitable[None]]]):
68
- self.on_resume_invoke = on_resume_invoke
69
- self.scheduler = scheduler
70
-
71
- class CIBAGraphOptions():
72
- """
73
- The base CIBA options.
74
-
75
- Attributes:
76
- config (CIBAGraphOptionsConfig): Configuration options.
77
- scope (Optional[str]): Space-separated list of OIDC and custom API scopes.
78
- on_approve_go_to (Optional[str]): A node name to redirect the flow after user approval.
79
- on_reject_go_to (Optional[str]): A node name to redirect the flow after user rejection.
80
- audience (Optional[str]): Unique identifier of the audience for an issued token.
81
- request_expiry (Optional[int]): To configure a custom expiry time in seconds for CIBA request, pass a number between 1 and 300.
82
- """
83
- def __init__(
84
- self,
85
- config: CIBAGraphOptionsConfig,
86
- scope: Optional[str] = None,
87
- on_approve_go_to: Optional[str] = None,
88
- on_reject_go_to: Optional[str] = None,
89
- audience: Optional[str] = None,
90
- request_expiry: Optional[int] = None,
91
-
92
- ):
93
- self.config = config
94
- self.scope = scope
95
- self.on_approve_go_to = on_approve_go_to
96
- self.on_reject_go_to = on_reject_go_to
97
- self.audience = audience
98
- self.request_expiry = request_expiry
99
-
100
- class ICIBAGraph(ABC):
101
- @abstractmethod
102
- def get_tools(self) -> List[ProtectedTool]:
103
- pass
104
-
105
- @abstractmethod
106
- def get_graph(self) -> StateGraph:
107
- pass
108
-
109
- @abstractmethod
110
- def get_authorizer_params(self) -> Optional[AuthorizerParams]:
111
- pass
112
-
113
- @abstractmethod
114
- def get_options(self) -> Optional[CIBAGraphOptions]:
115
- pass
@@ -1,17 +0,0 @@
1
- from typing import Optional, List
2
- from .types import ProtectedTool, BaseState
3
-
4
- def get_tool_definition(state: BaseState, tools: List[ProtectedTool]) -> Optional[dict]:
5
- message = state["messages"][-1]
6
-
7
- if not hasattr(message, "tool_calls") or not message.tool_calls:
8
- return None
9
-
10
- tool_calls = message.tool_calls
11
- tool = tool_calls[-1]
12
- metadata = next((t for t in tools if t.tool_name == tool["name"]), None)
13
-
14
- if not metadata:
15
- return None
16
-
17
- return {"metadata": metadata, "tool": tool, "message": message}
@@ -1,105 +0,0 @@
1
- import os
2
- from typing import Awaitable, Callable, Optional, TypedDict, Union
3
-
4
- from auth0_ai.authorizers.ciba_authorizer import (
5
- AuthorizeResponse,
6
- CIBAAuthorizer,
7
- CibaAuthorizerCheckResponse,
8
- )
9
- from auth0_ai.credentials import Credentials
10
- from auth0_ai.token_response import TokenResponse
11
- from langgraph.graph import END, START, StateGraph
12
- from langgraph_sdk import get_client
13
- from langgraph_sdk.schema import Command
14
-
15
- from auth0_ai_langchain.ciba.types import Auth0Graphs
16
-
17
-
18
- class State(TypedDict):
19
- ciba_response: AuthorizeResponse
20
- on_resume_invoke: str
21
- thread_id: str
22
- user_id: str
23
-
24
- # Internal
25
- task_id: str
26
- tool_id: str
27
- status: CibaAuthorizerCheckResponse
28
- token_response: Optional[TokenResponse]
29
-
30
-
31
- def ciba_poller_graph(on_stop_scheduler: Union[str, Callable[[State], Awaitable[None]]]):
32
- """
33
- A LangGraph graph to monitor the status of a CIBA transaction.
34
-
35
- Attributes:
36
- on_stop_scheduler (Union[str, Callable[[State], Awaitable[None]]]): A graph name to redirect the flow, or a function to execute when the CIBA transaction expires.
37
- """
38
- async def check_status(state: State):
39
- try:
40
- res = await CIBAAuthorizer.check(state["ciba_response"]["auth_req_id"])
41
- state["token_response"] = res.get("token")
42
- state["status"] = res.get("status")
43
- except Exception as e:
44
- print(f"Error in check_status: {e}")
45
- return state
46
-
47
- async def stop_scheduler(state: State):
48
- try:
49
- if isinstance(on_stop_scheduler, str):
50
- langgraph = get_client(url=os.getenv(
51
- "LANGGRAPH_API_URL", "http://localhost:54367"))
52
- await langgraph.crons.create_for_thread(state.thread_id, Auth0Graphs.CIBA_POLLER.value)
53
- elif callable(on_stop_scheduler):
54
- await on_stop_scheduler(state)
55
- except Exception as e:
56
- print(f"Error in stop_scheduler: {e}")
57
- return state
58
-
59
- async def resume_agent(state: State):
60
- langgraph = get_client(url=os.getenv(
61
- "LANGGRAPH_API_URL", "http://localhost:54367"))
62
- _credentials: Credentials = None
63
-
64
- try:
65
- if state["status"] == CibaAuthorizerCheckResponse.APPROVED:
66
- _credentials = {
67
- "access_token": {
68
- "type": state["token_response"].get("token_type", "Bearer"),
69
- "value": state["token_response"].get("access_token"),
70
- }
71
- }
72
-
73
- await langgraph.runs.wait(
74
- state["thread_id"],
75
- state["on_resume_invoke"],
76
- config={
77
- # this is only for this run / thread_id
78
- "configurable": {"_credentials": _credentials}
79
- },
80
- command=Command(resume={"status": state["status"]})
81
- )
82
- except Exception as e:
83
- print(f"Error in resume_agent: {e}")
84
-
85
- return state
86
-
87
- async def should_continue(state: State):
88
- status = state.get("status")
89
- if status == CibaAuthorizerCheckResponse.PENDING:
90
- return END
91
- elif status == CibaAuthorizerCheckResponse.EXPIRED:
92
- return "stop_scheduler"
93
- elif status in [CibaAuthorizerCheckResponse.APPROVED, CibaAuthorizerCheckResponse.REJECTED]:
94
- return "resume_agent"
95
- return END
96
-
97
- state_graph = StateGraph(State)
98
- state_graph.add_node("check_status", check_status)
99
- state_graph.add_node("stop_scheduler", stop_scheduler)
100
- state_graph.add_node("resume_agent", resume_agent)
101
- state_graph.add_edge(START, "check_status")
102
- state_graph.add_edge("resume_agent", "stop_scheduler")
103
- state_graph.add_conditional_edges("check_status", should_continue)
104
-
105
- return state_graph
@@ -1,8 +0,0 @@
1
- from enum import Enum
2
-
3
- class Auth0Nodes(Enum):
4
- AUTH0_CIBA_HITL = "AUTH0_CIBA_HITL"
5
- AUTH0_CIBA = "AUTH0_CIBA"
6
-
7
- class Auth0Graphs(Enum):
8
- CIBA_POLLER = "AUTH0_CIBA_POLLER"
@@ -1,3 +0,0 @@
1
- from auth0_ai.authorizers.fga_authorizer import FGAAuthorizer, FGAAuthorizerOptions
2
-
3
- __all__ = ["FGAAuthorizer", "FGAAuthorizerOptions"]
@@ -1,19 +0,0 @@
1
- auth0_ai_langchain/FGARetriever.py,sha256=6nQXRkbDLHZt9zYZJsS5iQljrogQVLW0aVwDIf6Mpac,6002
2
- auth0_ai_langchain/__init__.py,sha256=I331Kz-q97ZU7TfXaOR5UBbJamGEJ15twbf2HP1iCHs,67
3
- auth0_ai_langchain/auth0_ai.py,sha256=8NUV_80SxR8qQt_3RQGf0Oga178kChuROHuhz7rfOyU,1919
4
- auth0_ai_langchain/ciba/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- auth0_ai_langchain/ciba/ciba_graph/ciba_graph.py,sha256=Wi7qXSMzvcfqdO8WsUJejmQzOVM469TFJCkH7eRlaR8,4115
6
- auth0_ai_langchain/ciba/ciba_graph/initialize_ciba.py,sha256=a41KedBzxfLqG2AhvkFnemcEqWwQWVh1r-Oro-qJX-M,3752
7
- auth0_ai_langchain/ciba/ciba_graph/initialize_hitl.py,sha256=CR3jMolZYYOBHx1AXb6yERBaZThMg677qGGo_vRy6I8,1901
8
- auth0_ai_langchain/ciba/ciba_graph/types.py,sha256=NZS99vPOgJRc2O7mO5MWj6nAa4RSG1R5oSvzYhkz0RA,4234
9
- auth0_ai_langchain/ciba/ciba_graph/utils.py,sha256=ZPAh0Gs7Hj59_xngg8M7yx1v52dTn2pNpMNRpFKCSII,560
10
- auth0_ai_langchain/ciba/ciba_poller_graph.py,sha256=rjlxsTheJrN2J6E5BCE9aVyDO8V7UFhdKyX4KT_hB8k,3828
11
- auth0_ai_langchain/ciba/types.py,sha256=gybqYEprklZwcMBgaWFooBsJ1GcNUK8ZWRvAX5PZWdE,177
12
- auth0_ai_langchain/federated_connections/__init__.py,sha256=nWA0eZj88nkamiZz8Wx-KVZ9r1faNNjWBdQAYfrPO1A,480
13
- auth0_ai_langchain/federated_connections/federated_connection_authorizer.py,sha256=ZLF4p7fPTrODOeHWIShfBTsReSy27u73rIp0L_Umhjg,2218
14
- auth0_ai_langchain/fga/fga_authorizer.py,sha256=uDaGDSXaxQd1X-2w2zTvnfizMB-DtQ-1G6SIaDNBrho,137
15
- auth0_ai_langchain/utils/interrupt.py,sha256=HHmlwKwQR_sx8dsV_cZTCYLDo6n6JuiPyLJxbhiI84w,449
16
- auth0_ai_langchain-0.2.0.dist-info/LICENSE,sha256=Lu_2YH0oK8b_VVisAhNQ2WIdtwY8pSU2PLbll-y6Cj8,9792
17
- auth0_ai_langchain-0.2.0.dist-info/METADATA,sha256=Gn337MU0yUTAFP_DUyRhkyxj6j5jw4frYmVpJmPMqis,7793
18
- auth0_ai_langchain-0.2.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
19
- auth0_ai_langchain-0.2.0.dist-info/RECORD,,