beamlit 0.0.28rc24__py3-none-any.whl → 0.0.29__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,93 @@
1
+ import asyncio
2
+ from dataclasses import dataclass
3
+ import warnings
4
+ from typing import Any, Callable
5
+
6
+ import pydantic
7
+ import typing_extensions as t
8
+ from beamlit.api.agents import list_agents
9
+ from beamlit.authentication.authentication import AuthenticatedClient
10
+ from beamlit.models import Agent, AgentChain
11
+ from beamlit.run import RunClient
12
+ from beamlit.common.settings import get_settings
13
+ from langchain_core.tools.base import BaseTool, ToolException
14
+
15
+
16
+
17
class ChainTool(BaseTool):
    """
    LangChain tool that forwards a call to another agent of the chain.

    The tool's ``name`` is the target agent's name. The validated tool arguments
    are posted as the JSON body of a ``POST`` to that agent in the environment
    returned by ``get_settings()``. The raw response text is returned.
    """

    client: RunClient
    handle_tool_error: bool | str | Callable[[ToolException], str] | None = True

    def _call_agent(self, kwargs: dict[str, t.Any]) -> t.Any:
        """POST ``kwargs`` as JSON to the agent named ``self.name`` and return the response text."""
        settings = get_settings()
        result = self.client.run(
            "agent",
            self.name,
            settings.environment,
            "POST",
            json=kwargs,
        )
        return result.text

    @t.override
    def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        warnings.warn(
            "Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy standard tests.",
            stacklevel=1,
        )
        # The underlying client call is synchronous, so run it directly. Wrapping it in
        # asyncio.run() would raise RuntimeError whenever this thread already has a
        # running event loop (e.g. a sync invoke issued from async code).
        return self._call_agent(kwargs)

    @t.override
    async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        # NOTE(review): self.client.run is not awaited, so it presumably performs blocking
        # HTTP I/O on the event-loop thread; consider asyncio.to_thread - TODO confirm.
        return self._call_agent(kwargs)

    @t.override
    @property
    def tool_call_schema(self) -> type[pydantic.BaseModel]:
        """Schema advertised to the model for tool calls: the configured ``args_schema``.

        Raises:
            ValueError: If no ``args_schema`` was configured for this tool.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would then silently return None.
        if self.args_schema is None:
            raise ValueError(f"ChainTool {self.name!r} requires an args_schema")
        return self.args_schema
50
+
51
# Argument schema used as ``args_schema`` for every ChainTool built by
# ChainToolkit.get_tools, so a chained agent is called with a single string field.
# Comments (not a docstring) on purpose: a pydantic model docstring becomes the
# schema "description" and would change what is sent to the LLM.
class ChainInput(pydantic.BaseModel):
    # Text handed to the chained agent; posted in the JSON body as {"inputs": ...}.
    inputs: str
53
+
54
@dataclass
class ChainToolkit:
    """
    Toolkit exposing the enabled entries of an agent chain as LangChain tools.

    Call :meth:`initialize` once to resolve the chain against the agents returned
    by ``list_agents``, then :meth:`get_tools` to build one ``ChainTool`` per agent.
    """

    client: AuthenticatedClient
    chain: list[AgentChain]
    _chain: list[Agent] | None = None

    # NOTE(review): a pydantic setting on a stdlib dataclass has no effect; it is kept
    # only so the class attribute stays available to any existing readers.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    def initialize(self) -> None:
        """Fetch the agents and keep those referenced by enabled chain entries.

        Does nothing if the chain was already resolved. Chain entries with no
        matching agent name are skipped. When an entry has a ``description``, it
        replaces the matched agent's ``spec.description``.

        Raises:
            RuntimeError: If the agent listing returned no parsable body.
        """
        if self._chain is not None:
            return

        response = list_agents.sync_detailed(client=self.client)
        agents = response.parsed
        if agents is None:
            # Previously this surfaced as "'NoneType' object is not iterable".
            raise RuntimeError(
                f"Failed to list agents for the chain (HTTP status {response.status_code})"
            )

        # Index by name once instead of rescanning every agent for each chain entry;
        # setdefault keeps the first agent for a duplicated name, as the scan did.
        agents_by_name: dict[str, Agent] = {}
        for candidate in agents:
            agents_by_name.setdefault(candidate.metadata.name, candidate)

        agents_chain: list[Agent] = []
        for entry in self.chain:
            if not entry.enabled:
                continue
            agent = agents_by_name.get(entry.name)
            if agent is None:
                continue
            agent.spec.description = entry.description or agent.spec.description
            agents_chain.append(agent)
        self._chain = agents_chain

    # No @override here: ChainToolkit has no base class that defines get_tools.
    def get_tools(self) -> list[BaseTool]:
        """Build one ChainTool per resolved chain agent.

        Raises:
            RuntimeError: If :meth:`initialize` has not been called yet.
        """
        if self._chain is None:
            raise RuntimeError("Must initialize the toolkit first")

        return [
            ChainTool(
                client=RunClient(self.client),
                name=agent.metadata.name,
                description=agent.spec.description or "",
                args_schema=ChainInput,
            )
            for agent in self._chain
        ]
@@ -17,16 +17,18 @@ from langchain_core.tools import Tool
17
17
  from langgraph.checkpoint.memory import MemorySaver
18
18
  from langgraph.prebuilt import create_react_agent
19
19
 
20
+ from .chain import ChainToolkit
20
21
  from .chat import get_chat_model
21
22
 
22
23
 
23
- def get_functions(dir="src/functions", from_decorator="function"):
24
+ def get_functions(dir="src/functions", from_decorator="function", remote_functions_empty=True):
24
25
  functions = []
25
26
  logger = getLogger(__name__)
26
27
 
27
28
  # Walk through all Python files in functions directory and subdirectories
28
29
  if not os.path.exists(dir):
29
- logger.warn(f"Functions directory {dir} not found")
30
+ if remote_functions_empty:
31
+ logger.warn(f"Functions directory {dir} not found")
30
32
  return []
31
33
  for root, _, files in os.walk(dir):
32
34
  for file in files:
@@ -74,6 +76,7 @@ def get_functions(dir="src/functions", from_decorator="function"):
74
76
  kit_functions = get_functions(
75
77
  dir=os.path.join(root),
76
78
  from_decorator="kit",
79
+ remote_functions_empty=remote_functions_empty,
77
80
  )
78
81
  functions.extend(kit_functions)
79
82
 
@@ -134,7 +137,10 @@ def agent(
134
137
  return wrapped
135
138
 
136
139
  # Initialize functions array to store decorated functions
137
- functions = get_functions(dir=settings.agent.functions_directory)
140
+ functions = get_functions(
141
+ dir=settings.agent.functions_directory,
142
+ remote_functions_empty=not remote_functions,
143
+ )
138
144
  settings.agent.functions = functions
139
145
 
140
146
  if agent is not None:
@@ -178,7 +184,6 @@ def agent(
178
184
  logger.warn(f"Failed to initialize MCP server {server}: {e!s}")
179
185
 
180
186
  if remote_functions:
181
-
182
187
  for function in remote_functions:
183
188
  try:
184
189
  toolkit = RemoteToolkit(client, function)
@@ -187,6 +192,11 @@ def agent(
187
192
  except Exception as e:
188
193
  logger.warn(f"Failed to initialize remote function {function}: {e!s}")
189
194
 
195
+ if agent.spec.agent_chain:
196
+ toolkit = ChainToolkit(client, agent.spec.agent_chain)
197
+ toolkit.initialize()
198
+ functions.extend(toolkit.get_tools())
199
+
190
200
  if override_agent is None and len(functions) == 0:
191
201
  raise ValueError(
192
202
  "You must define at least one function, you can define this function in directory "
@@ -0,0 +1,171 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Optional, Union
3
+
4
+ import httpx
5
+
6
+ from ... import errors
7
+ from ...client import AuthenticatedClient, Client
8
+ from ...models.trace_ids_response import TraceIdsResponse
9
+ from ...types import UNSET, Response, Unset
10
+
11
+
12
def _get_kwargs(
    agent_name: str,
    *,
    environment: Union[Unset, str] = UNSET,
) -> dict[str, Any]:
    """Assemble httpx request arguments for ``GET /agents/{agent_name}/traces``.

    ``environment`` is sent as a query parameter only when it is neither
    ``UNSET`` nor ``None``.
    """
    candidates = (("environment", environment),)
    query: dict[str, Any] = {
        key: value
        for key, value in candidates
        if value is not UNSET and value is not None
    }
    return {
        "method": "get",
        "url": f"/agents/{agent_name}/traces",
        "params": query,
    }
30
+
31
+
32
def _parse_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[TraceIdsResponse]:
    """Decode a 200 response body into ``TraceIdsResponse``.

    For any other status, raise ``errors.UnexpectedStatus`` when the client has
    ``raise_on_unexpected_status`` set; otherwise return ``None``.
    """
    if response.status_code != 200:
        if client.raise_on_unexpected_status:
            raise errors.UnexpectedStatus(response.status_code, response.content)
        return None
    return TraceIdsResponse.from_dict(response.json())
43
+
44
+
45
def _build_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[TraceIdsResponse]:
    """Wrap a raw httpx response into a typed ``Response`` carrying its parsed body."""
    # Status conversion happens first, before parsing, as in the generated code.
    status = HTTPStatus(response.status_code)
    body, header_map = response.content, response.headers
    return Response(
        status_code=status,
        content=body,
        headers=header_map,
        parsed=_parse_response(client=client, response=response),
    )
54
+
55
+
56
def sync_detailed(
    agent_name: str,
    *,
    client: AuthenticatedClient,
    environment: Union[Unset, str] = UNSET,
) -> Response[TraceIdsResponse]:
    """Get agent trace IDs (blocking; returns the full response).

    Args:
        agent_name (str): Agent whose traces are requested; placed in the URL path.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[TraceIdsResponse]
    """
    request_kwargs = _get_kwargs(agent_name=agent_name, environment=environment)
    http_client = client.get_httpx_client()
    raw = http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw)
86
+
87
+
88
def sync(
    agent_name: str,
    *,
    client: AuthenticatedClient,
    environment: Union[Unset, str] = UNSET,
) -> Optional[TraceIdsResponse]:
    """Get agent trace IDs (blocking; returns only the parsed body).

    Args:
        agent_name (str): Agent whose traces are requested.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        TraceIdsResponse, or None for an unexpected status that was not raised.
    """
    detailed = sync_detailed(agent_name=agent_name, client=client, environment=environment)
    return detailed.parsed
113
+
114
+
115
async def asyncio_detailed(
    agent_name: str,
    *,
    client: AuthenticatedClient,
    environment: Union[Unset, str] = UNSET,
) -> Response[TraceIdsResponse]:
    """Get agent trace IDs (async; returns the full response).

    Args:
        agent_name (str): Agent whose traces are requested; placed in the URL path.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[TraceIdsResponse]
    """
    request_kwargs = _get_kwargs(agent_name=agent_name, environment=environment)
    async_http = client.get_async_httpx_client()
    raw = await async_http.request(**request_kwargs)
    return _build_response(client=client, response=raw)
143
+
144
+
145
async def asyncio(
    agent_name: str,
    *,
    client: AuthenticatedClient,
    environment: Union[Unset, str] = UNSET,
) -> Optional[TraceIdsResponse]:
    """Get agent trace IDs (async; returns only the parsed body).

    Args:
        agent_name (str): Agent whose traces are requested.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        TraceIdsResponse, or None for an unexpected status that was not raised.
    """
    detailed = await asyncio_detailed(
        agent_name=agent_name, client=client, environment=environment
    )
    return detailed.parsed
File without changes
@@ -0,0 +1,150 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Optional, Union
3
+
4
+ import httpx
5
+
6
+ from ... import errors
7
+ from ...client import AuthenticatedClient, Client
8
+ from ...models.get_trace_response_200 import GetTraceResponse200
9
+ from ...types import Response
10
+
11
+
12
+ def _get_kwargs(
13
+ trace_id: str,
14
+ ) -> dict[str, Any]:
15
+ _kwargs: dict[str, Any] = {
16
+ "method": "get",
17
+ "url": f"/traces/{trace_id}",
18
+ }
19
+
20
+ return _kwargs
21
+
22
+
23
def _parse_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[GetTraceResponse200]:
    """Decode a 200 response body into ``GetTraceResponse200``.

    For any other status, raise ``errors.UnexpectedStatus`` when the client has
    ``raise_on_unexpected_status`` set; otherwise return ``None``.
    """
    if response.status_code != 200:
        if client.raise_on_unexpected_status:
            raise errors.UnexpectedStatus(response.status_code, response.content)
        return None
    return GetTraceResponse200.from_dict(response.json())
34
+
35
+
36
def _build_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[GetTraceResponse200]:
    """Wrap a raw httpx response into a typed ``Response`` carrying its parsed body."""
    # Status conversion happens first, before parsing, as in the generated code.
    status = HTTPStatus(response.status_code)
    body, header_map = response.content, response.headers
    return Response(
        status_code=status,
        content=body,
        headers=header_map,
        parsed=_parse_response(client=client, response=response),
    )
45
+
46
+
47
def sync_detailed(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Response[GetTraceResponse200]:
    """Get a trace by ID (blocking; returns the full response).

    Args:
        trace_id (str): Identifier placed in the ``/traces/{trace_id}`` path.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceResponse200]
    """
    request_kwargs = _get_kwargs(trace_id=trace_id)
    http_client = client.get_httpx_client()
    raw = http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw)
74
+
75
+
76
def sync(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Optional[GetTraceResponse200]:
    """Get a trace by ID (blocking; returns only the parsed body).

    Args:
        trace_id (str): Identifier of the trace to fetch.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceResponse200, or None for an unexpected status that was not raised.
    """
    detailed = sync_detailed(trace_id=trace_id, client=client)
    return detailed.parsed
98
+
99
+
100
async def asyncio_detailed(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Response[GetTraceResponse200]:
    """Get a trace by ID (async; returns the full response).

    Args:
        trace_id (str): Identifier placed in the ``/traces/{trace_id}`` path.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceResponse200]
    """
    request_kwargs = _get_kwargs(trace_id=trace_id)
    async_http = client.get_async_httpx_client()
    raw = await async_http.request(**request_kwargs)
    return _build_response(client=client, response=raw)
125
+
126
+
127
async def asyncio(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Optional[GetTraceResponse200]:
    """Get a trace by ID (async; returns only the parsed body).

    Args:
        trace_id (str): Identifier of the trace to fetch.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceResponse200, or None for an unexpected status that was not raised.
    """
    detailed = await asyncio_detailed(trace_id=trace_id, client=client)
    return detailed.parsed
@@ -0,0 +1,233 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Optional, Union
3
+
4
+ import httpx
5
+
6
+ from ... import errors
7
+ from ...client import AuthenticatedClient, Client
8
+ from ...models.get_trace_ids_response_200 import GetTraceIdsResponse200
9
+ from ...types import UNSET, Response, Unset
10
+
11
+
12
def _get_kwargs(
    *,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> dict[str, Any]:
    """Assemble httpx request arguments for ``GET /traces``.

    Python argument names are mapped to the camelCase query keys the API
    expects. Filters that are ``UNSET`` or ``None`` are omitted.
    """
    wire_values = {
        "workloadId": workload_id,
        "workloadType": workload_type,
        "environment": environment,
        "limit": limit,
        "startTime": start_time,
        "endTime": end_time,
    }
    query: dict[str, Any] = {
        key: value
        for key, value in wire_values.items()
        if value is not UNSET and value is not None
    }
    return {"method": "get", "url": "/traces", "params": query}
44
+
45
+
46
def _parse_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[GetTraceIdsResponse200]:
    """Decode a 200 response body into ``GetTraceIdsResponse200``.

    For any other status, raise ``errors.UnexpectedStatus`` when the client has
    ``raise_on_unexpected_status`` set; otherwise return ``None``.
    """
    if response.status_code != 200:
        if client.raise_on_unexpected_status:
            raise errors.UnexpectedStatus(response.status_code, response.content)
        return None
    return GetTraceIdsResponse200.from_dict(response.json())
57
+
58
+
59
def _build_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[GetTraceIdsResponse200]:
    """Wrap a raw httpx response into a typed ``Response`` carrying its parsed body."""
    # Status conversion happens first, before parsing, as in the generated code.
    status = HTTPStatus(response.status_code)
    body, header_map = response.content, response.headers
    return Response(
        status_code=status,
        content=body,
        headers=header_map,
        parsed=_parse_response(client=client, response=response),
    )
68
+
69
+
70
def sync_detailed(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Response[GetTraceIdsResponse200]:
    """Get trace IDs (blocking; returns the full response).

    Args:
        workload_id (Union[Unset, str]): Sent as ``workloadId`` when set.
        workload_type (Union[Unset, str]): Sent as ``workloadType`` when set.
        environment (Union[Unset, str]): Sent as ``environment`` when set.
        limit (Union[Unset, str]): Sent as ``limit`` when set.
        start_time (Union[Unset, str]): Sent as ``startTime`` when set.
        end_time (Union[Unset, str]): Sent as ``endTime`` when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceIdsResponse200]
    """
    request_kwargs = _get_kwargs(
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    http_client = client.get_httpx_client()
    raw = http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw)
112
+
113
+
114
def sync(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Optional[GetTraceIdsResponse200]:
    """Get trace IDs (blocking; returns only the parsed body).

    Args:
        workload_id (Union[Unset, str]): Sent as ``workloadId`` when set.
        workload_type (Union[Unset, str]): Sent as ``workloadType`` when set.
        environment (Union[Unset, str]): Sent as ``environment`` when set.
        limit (Union[Unset, str]): Sent as ``limit`` when set.
        start_time (Union[Unset, str]): Sent as ``startTime`` when set.
        end_time (Union[Unset, str]): Sent as ``endTime`` when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceIdsResponse200, or None for an unexpected status that was not raised.
    """
    detailed = sync_detailed(
        client=client,
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    return detailed.parsed
151
+
152
+
153
async def asyncio_detailed(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Response[GetTraceIdsResponse200]:
    """Get trace IDs (async; returns the full response).

    Args:
        workload_id (Union[Unset, str]): Sent as ``workloadId`` when set.
        workload_type (Union[Unset, str]): Sent as ``workloadType`` when set.
        environment (Union[Unset, str]): Sent as ``environment`` when set.
        limit (Union[Unset, str]): Sent as ``limit`` when set.
        start_time (Union[Unset, str]): Sent as ``startTime`` when set.
        end_time (Union[Unset, str]): Sent as ``endTime`` when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceIdsResponse200]
    """
    request_kwargs = _get_kwargs(
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    async_http = client.get_async_httpx_client()
    raw = await async_http.request(**request_kwargs)
    return _build_response(client=client, response=raw)
193
+
194
+
195
async def asyncio(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Optional[GetTraceIdsResponse200]:
    """Get trace IDs (async; returns only the parsed body).

    Args:
        workload_id (Union[Unset, str]): Sent as ``workloadId`` when set.
        workload_type (Union[Unset, str]): Sent as ``workloadType`` when set.
        environment (Union[Unset, str]): Sent as ``environment`` when set.
        limit (Union[Unset, str]): Sent as ``limit`` when set.
        start_time (Union[Unset, str]): Sent as ``startTime`` when set.
        end_time (Union[Unset, str]): Sent as ``endTime`` when set.

    Raises:
        errors.UnexpectedStatus: If the status is not 200 and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceIdsResponse200, or None for an unexpected status that was not raised.
    """
    detailed = await asyncio_detailed(
        client=client,
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    return detailed.parsed