auth0-ai-langchain: 0.2.0 (py3-none-any wheel) → 1.0.0b2 (py3-none-any wheel)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auth0-ai-langchain might be problematic. (The original page linked to further details; that link is not included here.)

@@ -32,7 +32,7 @@ class FGARetriever(BaseRetriever):
32
32
  Args:
33
33
  retriever (BaseRetriever): The retriever used to fetch documents.
34
34
  build_query (Callable[[Document], ClientBatchCheckItem]): Function to convert documents into FGA queries.
35
- fga_configuration (Optional[ClientConfiguration]): Configuration for the OpenFGA client. If not provided, defaults to environment variables.
35
+ fga_configuration (ClientConfiguration, optional): Configuration for the OpenFGA client. If not provided, defaults to environment variables.
36
36
  """
37
37
  super().__init__()
38
38
  self._retriever = retriever
@@ -95,7 +95,7 @@ class FGARetriever(BaseRetriever):
95
95
 
96
96
  Args:
97
97
  query (str): The query for retrieving documents.
98
- run_manager (Optional[object]): Optional manager for tracking runs.
98
+ run_manager (object, optional): Optional manager for tracking runs.
99
99
 
100
100
  Returns:
101
101
  List[Document]: Filtered and relevant documents.
@@ -148,7 +148,7 @@ class FGARetriever(BaseRetriever):
148
148
 
149
149
  Args:
150
150
  query (str): The query for retrieving documents.
151
- run_manager (Optional[object]): Optional manager for tracking runs.
151
+ run_manager (object, optional): Optional manager for tracking runs.
152
152
 
153
153
  Returns:
154
154
  List[Document]: Filtered and relevant documents.
@@ -1,43 +1,113 @@
1
1
  from typing import Callable, Optional
2
- from langchain_core.runnables.config import RunnableConfig
3
2
  from langchain_core.tools import BaseTool
4
- from auth0_ai.credentials import Credential
5
- from auth0_ai.authorizers.types import AuthorizerParams
6
- from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerParams
7
- from .federated_connections.federated_connection_authorizer import FederatedConnectionAuthorizer
8
- from .ciba.ciba_graph.ciba_graph import CIBAGraph
9
- from .ciba.ciba_graph.types import CIBAGraphOptions
10
-
11
- def get_access_token(config: RunnableConfig) -> Credential:
12
- """
13
- Fetch the access token obtained during the CIBA flow.
3
+ from auth0_ai.authorizers.ciba import CIBAAuthorizerParams
4
+ from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerParams
5
+ from auth0_ai.authorizers.types import Auth0ClientParams
6
+ from auth0_ai_langchain.ciba.ciba_authorizer import CIBAAuthorizer
7
+ from auth0_ai_langchain.federated_connections.federated_connection_authorizer import FederatedConnectionAuthorizer
8
+
14
9
 
15
- Attributes:
16
- config(RunnableConfig): LangGraph runnable configuration instance.
10
+ class Auth0AI:
11
+ """Provides decorators to secure LangChain tools using Auth0 authorization flows.
17
12
  """
18
- return config.get("configurable", {}).get("_credentials", {}).get("access_token")
19
13
 
20
- class Auth0AI():
21
- def __init__(self, config: Optional[AuthorizerParams] = None):
22
- self._graph: Optional[CIBAGraph] = None
23
- self.config = config
14
+ def __init__(self, auth0: Optional[Auth0ClientParams] = None):
15
+ """Initializes the Auth0AI instance.
24
16
 
25
- def with_async_user_confirmation(self, **options: CIBAGraphOptions) -> CIBAGraph:
17
+ Args:
18
+ auth0 (Optional[Auth0ClientParams]): Parameters for the Auth0 client.
19
+ If not provided, values will be automatically read from environment
20
+ variables: `AUTH0_DOMAIN`, `AUTH0_CLIENT_ID`, and `AUTH0_CLIENT_SECRET`.
26
21
  """
27
- Initializes and registers a state graph for conditional trade operations using CIBA.
22
+ self.auth0 = auth0
28
23
 
29
- Attributes:
30
- options (Optional[CIBAGraphOptions]): The base CIBA options.
31
- """
32
- self._graph = CIBAGraph(CIBAGraphOptions(**options), self.config)
33
- return self._graph
34
-
35
- def with_federated_connection(self, **options: FederatedConnectionAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
24
+ def with_async_user_confirmation(self, **params: CIBAAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
25
+ """Protects a tool with the CIBA (Client-Initiated Backchannel Authentication) flow.
26
+
27
+ Requires user confirmation via a second device (e.g., phone)
28
+ before allowing the tool to execute.
29
+
30
+ Args:
31
+ **params: Parameters defined in `CIBAAuthorizerParams`.
32
+
33
+ Returns:
34
+ Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.
35
+
36
+ Example:
37
+ ```python
38
+ import os
39
+ from auth0_ai_langchain.auth0_ai import Auth0AI
40
+ from auth0_ai_langchain.ciba import get_ciba_credentials
41
+ from langchain_core.runnables import ensure_config
42
+ from langchain_core.tools import StructuredTool
43
+
44
+ auth0_ai = Auth0AI()
45
+
46
+ with_async_user_confirmation = auth0_ai.with_async_user_confirmation(
47
+ scopes=["stock:trade"],
48
+ audience=os.getenv("AUDIENCE"),
49
+ binding_message=lambda ticker, qty: f"Authorize the purchase of {qty} {ticker}",
50
+ user_id=lambda *_, **__: ensure_config().get("configurable", {}).get("user_id")
51
+ )
52
+
53
+ def tool_function(ticker: str, qty: int) -> str:
54
+ credentials = get_ciba_credentials()
55
+ headers = {
56
+ "Authorization": f"{credentials['token_type']} {credentials['access_token']}",
57
+ # ...
58
+ }
59
+ # Call API
60
+
61
+ trade_tool = with_async_user_confirmation(
62
+ StructuredTool(
63
+ name="trade_tool",
64
+ description="Use this function to trade a stock",
65
+ func=tool_function,
66
+ )
67
+ )
68
+ ```
36
69
  """
37
- Protects a tool execution with the Federated Connection authorizer.
70
+ authorizer = CIBAAuthorizer(CIBAAuthorizerParams(**params), self.auth0)
71
+ return authorizer.authorizer()
72
+
73
+ def with_federated_connection(self, **params: FederatedConnectionAuthorizerParams) -> Callable[[BaseTool], BaseTool]:
74
+ """Enables a tool to obtain an access token from a federated identity provider (e.g., Google, Azure AD).
75
+
76
+ The token can then be used within the tool to call third-party APIs on behalf of the user.
77
+
78
+ Args:
79
+ **params: Parameters defined in `FederatedConnectionAuthorizerParams`.
80
+
81
+ Returns:
82
+ Callable[[BaseTool], BaseTool]: A decorator to wrap a LangChain tool.
83
+
84
+ Example:
85
+ ```python
86
+ from auth0_ai_langchain.auth0_ai import Auth0AI
87
+ from auth0_ai_langchain.federated_connections import get_credentials_for_connection
88
+ from langchain_core.tools import StructuredTool
89
+ from datetime import datetime
90
+
91
+ auth0_ai = Auth0AI()
92
+
93
+ with_google_calendar_access = auth0_ai.with_federated_connection(
94
+ connection="google-oauth2",
95
+ scopes=["https://www.googleapis.com/auth/calendar.freebusy"]
96
+ )
97
+
98
+ def tool_function(date: datetime):
99
+ credentials = get_credentials_for_connection()
100
+ # Call Google API using credentials["access_token"]
38
101
 
39
- Attributes:
40
- options (FederatedConnectionAuthorizerParams): The Federated Connections authorizer options.
102
+ check_calendar_tool = with_google_calendar_access(
103
+ StructuredTool(
104
+ name="check_user_calendar",
105
+ description="Use this function to check if the user is available on a certain date and time",
106
+ func=tool_function,
107
+ )
108
+ )
109
+ ```
41
110
  """
42
- authorizer = FederatedConnectionAuthorizer(FederatedConnectionAuthorizerParams(**options), self.config)
111
+ authorizer = FederatedConnectionAuthorizer(
112
+ FederatedConnectionAuthorizerParams(**params), self.auth0)
43
113
  return authorizer.authorizer()
@@ -0,0 +1,3 @@
1
+ from auth0_ai.authorizers.ciba.ciba_authorizer_base import get_ciba_credentials as get_ciba_credentials
2
+ from auth0_ai_langchain.ciba.ciba_authorizer import CIBAAuthorizer as CIBAAuthorizer
3
+ from auth0_ai_langchain.ciba.graph_resumer import GraphResumer as GraphResumer
@@ -0,0 +1,17 @@
1
+ from abc import ABC
2
+ from typing import Union
3
+ from auth0_ai.authorizers.ciba import CIBAAuthorizerBase
4
+ from auth0_ai.interrupts.ciba_interrupts import AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
5
+ from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
6
+ from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
7
+ from langchain_core.tools import BaseTool
8
+
9
+ class CIBAAuthorizer(CIBAAuthorizerBase, ABC):
10
+ def _handle_authorization_interrupts(self, err: Union[AuthorizationPendingInterrupt, AuthorizationPollingInterrupt]) -> None:
11
+ raise to_graph_interrupt(err)
12
+
13
+ def authorizer(self):
14
+ def wrap_tool(tool: BaseTool) -> BaseTool:
15
+ return tool_wrapper(tool, self.protect)
16
+
17
+ return wrap_tool
@@ -0,0 +1,154 @@
1
+ import asyncio
2
+ from threading import Event
3
+ from typing import Callable, Optional, Dict, Any, List, TypedDict
4
+ from auth0_ai.authorizers.ciba import CIBAAuthorizationRequest
5
+ from auth0_ai.interrupts.ciba_interrupts import CIBAInterrupt, AuthorizationPendingInterrupt, AuthorizationPollingInterrupt
6
+ from auth0_ai_langchain.utils.interrupt import get_auth0_interrupts
7
+ from langgraph_sdk.client import LangGraphClient
8
+ from langgraph_sdk.schema import Thread, Interrupt
9
+
10
+ class WatchedThread(TypedDict):
11
+ thread_id: str
12
+ assistant_id: str
13
+ interruption_id: str
14
+ auth_request: CIBAAuthorizationRequest
15
+ config: Dict[str, Any]
16
+ last_run: float
17
+
18
+ class GraphResumerFilters(TypedDict):
19
+ graph_id: str
20
+
21
+ class GraphResumer:
22
+ def __init__(self, lang_graph: LangGraphClient, filters: Optional[GraphResumerFilters] = None):
23
+ self.lang_graph = lang_graph
24
+ self.filters = filters or {}
25
+ self.map: Dict[str, WatchedThread] = {}
26
+ self._stop_event = Event()
27
+ self._loop_task: Optional[asyncio.Task] = None
28
+
29
+ # Event callbacks
30
+ self._resume_callbacks: List[Callable[[WatchedThread], None]] = []
31
+ self._error_callbacks: List[Callable[[Exception], None]] = []
32
+
33
+ # Public API to register event callbacks
34
+ def on_resume(self, callback: Callable[[WatchedThread], None]) -> "GraphResumer":
35
+ self._resume_callbacks.append(callback)
36
+ return self
37
+
38
+ def on_error(self, callback: Callable[[Exception], None]) -> "GraphResumer":
39
+ self._error_callbacks.append(callback)
40
+ return self
41
+
42
+ def _emit_resume(self, thread: WatchedThread) -> None:
43
+ for callback in self._resume_callbacks:
44
+ callback(thread)
45
+
46
+ def _emit_error(self, error: Exception) -> None:
47
+ for callback in self._error_callbacks:
48
+ callback(error)
49
+
50
+ async def _get_all_interrupted_threads(self) -> List[Thread]:
51
+ interrupted_threads: List[Thread] = []
52
+ offset = 0
53
+
54
+ while True:
55
+ page = await self.lang_graph.threads.search(
56
+ status="interrupted",
57
+ limit=100,
58
+ offset=offset,
59
+ metadata={"graph_id": self.filters["graph_id"]} if "graph_id" in self.filters else None
60
+ )
61
+
62
+ if not page:
63
+ break
64
+
65
+ for t in page:
66
+ interrupt = self._get_first_interrupt(t)
67
+ if interrupt and CIBAInterrupt.is_interrupt(interrupt["value"]) and CIBAInterrupt.has_request_data(interrupt["value"]):
68
+ interrupted_threads.append(t)
69
+
70
+ offset += len(page)
71
+ if len(page) < 100:
72
+ break
73
+
74
+ return interrupted_threads
75
+
76
+ def _get_first_interrupt(self, thread: Thread) -> Optional[Interrupt]:
77
+ interrupts = thread["interrupts"]
78
+ if interrupts:
79
+ values = list(interrupts.values())
80
+ if values and values[0]:
81
+ return values[0][0]
82
+ return None
83
+
84
+ def _get_hash_map_id(self, thread: Thread) -> str:
85
+ return f"{thread['thread_id']}:{next(iter(thread['interrupts']))}"
86
+
87
+ async def _resume_thread(self, t: WatchedThread):
88
+ self._emit_resume(t)
89
+
90
+ await self.lang_graph.runs.wait(t["thread_id"], t["assistant_id"], config=t["config"])
91
+
92
+ t["last_run"] = asyncio.get_event_loop().time() * 1000
93
+
94
+ async def loop(self):
95
+ all_threads = await self._get_all_interrupted_threads()
96
+
97
+ # Remove old interrupted threads
98
+ active_keys = {self._get_hash_map_id(t) for t in all_threads}
99
+
100
+ for key in list(self.map.keys()):
101
+ if key not in active_keys:
102
+ del self.map[key]
103
+
104
+ # Add new interrupted threads
105
+ for thread in all_threads:
106
+ interrupt = next(
107
+ (i for i in get_auth0_interrupts(thread)
108
+ if AuthorizationPendingInterrupt.is_interrupt(i["value"])
109
+ or AuthorizationPollingInterrupt.is_interrupt(i["value"])),
110
+ None
111
+ )
112
+
113
+ if not interrupt or not interrupt["value"].get("_request"):
114
+ continue
115
+
116
+ key = self._get_hash_map_id(thread)
117
+ if key not in self.map:
118
+ self.map[key] = {
119
+ "thread_id": thread["thread_id"],
120
+ "assistant_id": thread["metadata"].get("graph_id"),
121
+ "config": getattr(thread, "config", {}),
122
+ "interruption_id": next(iter(thread["interrupts"])),
123
+ "auth_request": interrupt["value"]["_request"],
124
+ }
125
+
126
+ threads_to_resume = [
127
+ t for t in self.map.values()
128
+ if "last_run" not in t or (t["last_run"] + t["auth_request"]["interval"] * 1000 < asyncio.get_event_loop().time() * 1000)
129
+ ]
130
+
131
+ await asyncio.gather(*[
132
+ self._resume_thread(t) for t in threads_to_resume
133
+ ])
134
+
135
+ def start(self):
136
+ if self._loop_task and not self._loop_task.done():
137
+ return
138
+
139
+ self._stop_event.clear()
140
+
141
+ async def _run_loop():
142
+ while not self._stop_event.is_set():
143
+ try:
144
+ await self.loop()
145
+ except Exception as e:
146
+ self._emit_error(e)
147
+ await asyncio.sleep(5)
148
+
149
+ self._loop_task = asyncio.create_task(_run_loop())
150
+
151
+ def stop(self):
152
+ self._stop_event.set()
153
+ if self._loop_task:
154
+ self._loop_task.cancel()
@@ -1,4 +1,10 @@
1
- from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionError as FederatedConnectionError
2
- from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionInterrupt as FederatedConnectionInterrupt
3
- from auth0_ai.authorizers.federated_connection_authorizer import get_access_token_for_connection as get_access_token_for_connection
1
+ from auth0_ai.interrupts.federated_connection_interrupt import (
2
+ FederatedConnectionError as FederatedConnectionError,
3
+ FederatedConnectionInterrupt as FederatedConnectionInterrupt
4
+ )
5
+
6
+ from auth0_ai.authorizers.federated_connection_authorizer import (
7
+ get_credentials_for_connection as get_credentials_for_connection,
8
+ get_access_token_for_connection as get_access_token_for_connection
9
+ )
4
10
  from .federated_connection_authorizer import FederatedConnectionAuthorizer as FederatedConnectionAuthorizer
@@ -1,52 +1,33 @@
1
1
  import copy
2
2
  from abc import ABC
3
3
  from auth0_ai.authorizers.federated_connection_authorizer import FederatedConnectionAuthorizerBase, FederatedConnectionAuthorizerParams
4
- from auth0_ai.authorizers.types import AuthorizerParams
4
+ from auth0_ai.authorizers.types import Auth0ClientParams
5
5
  from auth0_ai.interrupts.federated_connection_interrupt import FederatedConnectionInterrupt
6
- from langchain_core.tools import BaseTool, tool
6
+ from auth0_ai_langchain.utils.interrupt import to_graph_interrupt
7
+ from auth0_ai_langchain.utils.tool_wrapper import tool_wrapper
8
+ from langchain_core.tools import BaseTool
7
9
  from langchain_core.runnables import ensure_config
8
- from ..utils.interrupt import to_graph_interrupt
9
10
 
10
- async def get_refresh_token(*_args, **_kwargs) -> str | None:
11
+ async def default_get_refresh_token(*_, **__) -> str | None:
11
12
  return ensure_config().get("configurable", {}).get("_credentials", {}).get("refresh_token")
12
13
 
13
14
  class FederatedConnectionAuthorizer(FederatedConnectionAuthorizerBase, ABC):
14
15
  def __init__(
15
16
  self,
16
- options: FederatedConnectionAuthorizerParams,
17
- config: AuthorizerParams = None,
17
+ params: FederatedConnectionAuthorizerParams,
18
+ auth0: Auth0ClientParams = None,
18
19
  ):
19
- if options.refresh_token.value is None:
20
- options = copy.copy(options)
21
- options.refresh_token.value = get_refresh_token
20
+ if params.refresh_token.value is None:
21
+ params = copy.copy(params)
22
+ params.refresh_token.value = default_get_refresh_token
22
23
 
23
- super().__init__(options, config)
24
+ super().__init__(params, auth0)
24
25
 
25
26
  def _handle_authorization_interrupts(self, err: FederatedConnectionInterrupt) -> None:
26
27
  raise to_graph_interrupt(err)
27
28
 
28
29
  def authorizer(self):
29
- def wrapped_tool(t: BaseTool) -> BaseTool:
30
- async def execute_fn(*_args, **kwargs):
31
- return await t.ainvoke(input=kwargs)
32
-
33
- tool_fn = self.protect(
34
- lambda *_args, **_kwargs: {
35
- "thread_id": ensure_config().get("configurable", {}).get("thread_id"),
36
- "checkpoint_ns": ensure_config().get("configurable", {}).get("checkpoint_ns"),
37
- "run_id": ensure_config().get("configurable", {}).get("run_id"),
38
- "tool_call_id": ensure_config().get("configurable", {}).get("tool_call_id"), # TODO: review this
39
- },
40
- execute_fn
41
- )
42
- tool_fn.__name__ = t.name
43
-
44
- return tool(
45
- tool_fn,
46
- description=t.description,
47
- return_direct=t.return_direct,
48
- args_schema=t.args_schema,
49
- response_format=t.response_format,
50
- )
30
+ def wrap_tool(tool: BaseTool) -> BaseTool:
31
+ return tool_wrapper(tool, self.protect)
51
32
 
52
- return wrapped_tool
33
+ return wrap_tool
@@ -0,0 +1,4 @@
1
+ from auth0_ai.authorizers.fga_authorizer import (
2
+ FGAAuthorizer as FGAAuthorizer,
3
+ FGAAuthorizerOptions as FGAAuthorizerOptions
4
+ )
@@ -1,6 +1,9 @@
1
+ from typing import List
1
2
  from auth0_ai.interrupts.auth0_interrupt import Auth0Interrupt
2
3
  from langgraph.errors import GraphInterrupt
3
4
  from langgraph.types import Interrupt
5
+ from langgraph_sdk.schema import Thread
6
+
4
7
 
5
8
  def to_graph_interrupt(interrupt: Auth0Interrupt) -> GraphInterrupt:
6
9
  return GraphInterrupt([
@@ -8,6 +11,20 @@ def to_graph_interrupt(interrupt: Auth0Interrupt) -> GraphInterrupt:
8
11
  value=interrupt.to_json(),
9
12
  when="during",
10
13
  resumable=True,
11
- ns=[f"auth0AI:{interrupt.__class__.__name__}:{interrupt.code}"]
14
+ ns=[f"auth0AI:{interrupt.name}:{interrupt.code}"]
12
15
  )
13
16
  ])
17
+
18
+
19
+ def get_auth0_interrupts(thread: Thread) -> List[Interrupt]:
20
+ result = []
21
+
22
+ if "interrupts" not in thread:
23
+ return result
24
+
25
+ for interrupt_list in thread["interrupts"].values():
26
+ for interrupt in interrupt_list:
27
+ if Auth0Interrupt.is_interrupt(interrupt["value"]):
28
+ result.append(interrupt)
29
+
30
+ return result
@@ -0,0 +1,34 @@
1
+ from typing import Callable
2
+ from typing_extensions import Annotated
3
+ from pydantic import create_model
4
+ from langchain_core.tools import BaseTool, tool as create_tool, InjectedToolCallId
5
+ from langchain_core.runnables import RunnableConfig
6
+
7
+ def tool_wrapper(tool: BaseTool, protect_fn: Callable) -> BaseTool:
8
+
9
+ # Workaround: extend existing args_schema to be able to get the tool_call_id value
10
+ args_schema = create_model(
11
+ tool.args_schema.__name__ + "Extended",
12
+ __base__=tool.args_schema,
13
+ **{"tool_call_id": (Annotated[str, InjectedToolCallId])}
14
+ )
15
+
16
+ @create_tool(
17
+ tool.name,
18
+ description=tool.description,
19
+ args_schema=args_schema
20
+ )
21
+ async def wrapped_tool(config: RunnableConfig, tool_call_id: Annotated[str, InjectedToolCallId], **input):
22
+ async def execute_fn(*_, **__):
23
+ return await tool.ainvoke(input, config)
24
+
25
+ return await protect_fn(
26
+ lambda *_, **__: {
27
+ "thread_id": config.get("configurable", {}).get("thread_id"),
28
+ "tool_call_id": tool_call_id,
29
+ "tool_name": tool.name,
30
+ },
31
+ execute_fn,
32
+ )(**input)
33
+
34
+ return wrapped_tool