beamlit 0.0.28rc24__py3-none-any.whl → 0.0.29rc26__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,93 @@
1
+ import asyncio
2
+ from dataclasses import dataclass
3
+ import warnings
4
+ from typing import Any, Callable
5
+
6
+ import pydantic
7
+ import typing_extensions as t
8
+ from beamlit.api.agents import list_agents
9
+ from beamlit.authentication.authentication import AuthenticatedClient
10
+ from beamlit.models import Agent, AgentChain
11
+ from beamlit.run import RunClient
12
+ from beamlit.common.settings import get_settings
13
+ from langchain_core.tools.base import BaseTool, ToolException
14
+
15
+
16
+
17
class ChainTool(BaseTool):
    """
    LangChain tool that forwards a call to another agent in the chain.

    The tool's ``name`` is used as the target passed to ``client.run`` and the
    keyword arguments of the call are sent as the JSON body of a ``POST``.
    Positional arguments are accepted for interface compatibility but are not
    forwarded.
    """

    client: RunClient
    handle_tool_error: bool | str | Callable[[ToolException], str] | None = True

    def _call_agent(self, payload: dict[str, t.Any]) -> t.Any:
        """Send ``payload`` as JSON to the agent named ``self.name`` and return the response text."""
        settings = get_settings()
        result = self.client.run(
            "agent",
            self.name,
            settings.environment,
            "POST",
            json=payload,
        )
        return result.text

    @t.override
    def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Synchronous entry point; prefer ``ainvoke``.

        Returns:
            The text of the response returned by the run client.
        """
        warnings.warn(
            "Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy standard tests.",
            stacklevel=1,
        )
        # The run client is synchronous, so call it directly instead of going
        # through asyncio.run(), which raises RuntimeError when an event loop
        # is already running in the current thread.
        return self._call_agent(kwargs)

    @t.override
    async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Asynchronous entry point.

        Returns:
            The text of the response returned by the run client.
        """
        # client.run() is a blocking call (its result is used without being
        # awaited), so run it in a worker thread to keep the event loop free.
        return await asyncio.to_thread(self._call_agent, kwargs)

    @t.override
    @property
    def tool_call_schema(self) -> type[pydantic.BaseModel]:
        """Return ``args_schema`` unchanged as the schema exposed for tool calls."""
        assert self.args_schema is not None  # noqa: S101
        return self.args_schema
50
+
51
class ChainInput(pydantic.BaseModel):
    """Argument schema for chain tools: a single string field named ``inputs``."""

    inputs: str
53
+
54
@dataclass
class ChainToolkit:
    """
    Toolkit that exposes the enabled agents of an agent chain as tools.

    Call ``initialize()`` once to resolve the chain entries against the agents
    returned by the API, then ``get_tools()`` to build one ``ChainTool`` per
    resolved agent.
    """

    # Authenticated API client used to list agents and to build run clients.
    client: AuthenticatedClient
    # Chain entries as configured; only entries whose ``enabled`` is truthy are used.
    chain: list[AgentChain]
    # Resolved agents, populated by ``initialize()``; ``None`` until then.
    _chain: list[Agent] | None = None

    # NOTE(review): this is a plain class attribute on a dataclass (it has no
    # annotation, so it is not a dataclass field); kept for compatibility.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    def initialize(self) -> None:
        """Resolve the enabled chain entries to agents and cache the result.

        Does nothing when the toolkit has already been initialized. Chain
        entries that match no listed agent are skipped. A chain entry's
        description, when set, overrides the matched agent's description.

        Raises:
            RuntimeError: If the agent listing returned no parsed result.
        """
        if self._chain is not None:
            return

        agents = list_agents.sync_detailed(
            client=self.client,
        ).parsed
        if agents is None:
            # Without this guard the lookup below fails with an opaque
            # "'NoneType' object is not iterable" TypeError.
            raise RuntimeError("Failed to list agents: the API returned no parsed agent list")

        resolved = []
        for link in self.chain:
            if not link.enabled:
                continue
            # First agent whose name matches the chain entry, if any.
            match = next(
                (agent for agent in agents if agent.metadata.name == link.name),
                None,
            )
            if match is None:
                continue
            match.spec.description = link.description or match.spec.description
            resolved.append(match)
        self._chain = resolved

    @t.override
    def get_tools(self) -> list[BaseTool]:
        """Build one ``ChainTool`` per resolved agent.

        Raises:
            RuntimeError: If ``initialize()`` has not been called yet.
        """
        if self._chain is None:
            raise RuntimeError("Must initialize the toolkit first")

        return [
            ChainTool(
                client=RunClient(self.client),
                name=agent.metadata.name,
                description=agent.spec.description or "",
                args_schema=ChainInput,
            )
            for agent in self._chain
        ]
@@ -18,7 +18,7 @@ from langgraph.checkpoint.memory import MemorySaver
18
18
  from langgraph.prebuilt import create_react_agent
19
19
 
20
20
  from .chat import get_chat_model
21
-
21
+ from .chain import ChainToolkit
22
22
 
23
23
  def get_functions(dir="src/functions", from_decorator="function"):
24
24
  functions = []
@@ -178,7 +178,6 @@ def agent(
178
178
  logger.warn(f"Failed to initialize MCP server {server}: {e!s}")
179
179
 
180
180
  if remote_functions:
181
-
182
181
  for function in remote_functions:
183
182
  try:
184
183
  toolkit = RemoteToolkit(client, function)
@@ -187,6 +186,11 @@ def agent(
187
186
  except Exception as e:
188
187
  logger.warn(f"Failed to initialize remote function {function}: {e!s}")
189
188
 
189
+ if agent.spec.agent_chain:
190
+ toolkit = ChainToolkit(client, agent.spec.agent_chain)
191
+ toolkit.initialize()
192
+ functions.extend(toolkit.get_tools())
193
+
190
194
  if override_agent is None and len(functions) == 0:
191
195
  raise ValueError(
192
196
  "You must define at least one function, you can define this function in directory "
@@ -0,0 +1,106 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Optional, Union
3
+
4
+ import httpx
5
+
6
+ from ... import errors
7
+ from ...client import AuthenticatedClient, Client
8
+ from ...types import UNSET, Response, Unset
9
+
10
+
11
def _get_kwargs(
    agent_name: str,
    *,
    environment: Union[Unset, str] = UNSET,
) -> dict[str, Any]:
    """Build the httpx request arguments for listing an agent's traces.

    Args:
        agent_name: Name of the agent; percent-encoded before being placed in
            the URL path.
        environment: Optional ``environment`` query parameter; omitted from
            the request when ``UNSET`` or ``None``.

    Returns:
        A dict with the ``method``, ``url`` and ``params`` keys.
    """
    from urllib.parse import quote

    params: dict[str, Any] = {}

    params["environment"] = environment

    params = {k: v for k, v in params.items() if v is not UNSET and v is not None}

    # Escape the name so characters such as "/" or "?" cannot change the
    # request path or leak into the query string.
    encoded_name = quote(str(agent_name), safe="")

    _kwargs: dict[str, Any] = {
        "method": "get",
        "url": f"/agents/{encoded_name}/traces",
        "params": params,
    }

    return _kwargs
29
+
30
+
31
def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[Any]:
    """Handle a response for an endpoint that has no parsed payload.

    No status code is special-cased here: every response either raises
    ``errors.UnexpectedStatus`` (when ``client.raise_on_unexpected_status`` is
    truthy) or yields ``None``.
    """
    if not client.raise_on_unexpected_status:
        return None
    raise errors.UnexpectedStatus(response.status_code, response.content)
36
+
37
+
38
def _build_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Response[Any]:
    """Wrap an httpx response into a ``Response`` carrying status, body, headers and parsed value."""
    # Evaluated in the same order as the fields are listed below.
    status = HTTPStatus(response.status_code)
    body = response.content
    headers = response.headers
    parsed = _parse_response(client=client, response=response)
    return Response(
        status_code=status,
        content=body,
        headers=headers,
        parsed=parsed,
    )
45
+
46
+
47
def sync_detailed(
    agent_name: str,
    *,
    client: AuthenticatedClient,
    environment: Union[Unset, str] = UNSET,
) -> Response[Any]:
    """Get agent trace IDs (blocking).

    Args:
        agent_name (str): Name of the agent whose traces are requested.
        environment (Union[Unset, str]): Optional environment query parameter.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[Any]
    """
    request_kwargs = _get_kwargs(agent_name=agent_name, environment=environment)
    http_client = client.get_httpx_client()
    raw_response = http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw_response)
77
+
78
+
79
async def asyncio_detailed(
    agent_name: str,
    *,
    client: AuthenticatedClient,
    environment: Union[Unset, str] = UNSET,
) -> Response[Any]:
    """Get agent trace IDs (awaitable).

    Args:
        agent_name (str): Name of the agent whose traces are requested.
        environment (Union[Unset, str]): Optional environment query parameter.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[Any]
    """
    request_kwargs = _get_kwargs(agent_name=agent_name, environment=environment)
    http_client = client.get_async_httpx_client()
    raw_response = await http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw_response)
File without changes
@@ -0,0 +1,150 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Optional, Union
3
+
4
+ import httpx
5
+
6
+ from ... import errors
7
+ from ...client import AuthenticatedClient, Client
8
+ from ...models.get_trace_response_200 import GetTraceResponse200
9
+ from ...types import Response
10
+
11
+
12
+ def _get_kwargs(
13
+ trace_id: str,
14
+ ) -> dict[str, Any]:
15
+ _kwargs: dict[str, Any] = {
16
+ "method": "get",
17
+ "url": f"/traces/{trace_id}",
18
+ }
19
+
20
+ return _kwargs
21
+
22
+
23
def _parse_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[GetTraceResponse200]:
    """Decode a 200 response into ``GetTraceResponse200``.

    Any other status raises ``errors.UnexpectedStatus`` when
    ``client.raise_on_unexpected_status`` is truthy, and yields ``None``
    otherwise.
    """
    if response.status_code != 200:
        if client.raise_on_unexpected_status:
            raise errors.UnexpectedStatus(response.status_code, response.content)
        return None

    return GetTraceResponse200.from_dict(response.json())
34
+
35
+
36
def _build_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[GetTraceResponse200]:
    """Wrap an httpx response into a ``Response`` carrying status, body, headers and parsed value."""
    # Evaluated in the same order as the fields are listed below.
    status = HTTPStatus(response.status_code)
    body = response.content
    headers = response.headers
    parsed = _parse_response(client=client, response=response)
    return Response(
        status_code=status,
        content=body,
        headers=headers,
        parsed=parsed,
    )
45
+
46
+
47
def sync_detailed(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Response[GetTraceResponse200]:
    """Get trace by ID (blocking), returning the full response wrapper.

    Args:
        trace_id (str): Identifier of the trace to fetch.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceResponse200]
    """
    request_kwargs = _get_kwargs(trace_id=trace_id)
    http_client = client.get_httpx_client()
    raw_response = http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw_response)
74
+
75
+
76
def sync(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Optional[GetTraceResponse200]:
    """Get trace by ID (blocking), returning only the parsed payload.

    Args:
        trace_id (str): Identifier of the trace to fetch.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceResponse200
    """
    detailed = sync_detailed(trace_id=trace_id, client=client)
    return detailed.parsed
98
+
99
+
100
async def asyncio_detailed(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Response[GetTraceResponse200]:
    """Get trace by ID (awaitable), returning the full response wrapper.

    Args:
        trace_id (str): Identifier of the trace to fetch.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceResponse200]
    """
    request_kwargs = _get_kwargs(trace_id=trace_id)
    http_client = client.get_async_httpx_client()
    raw_response = await http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw_response)
125
+
126
+
127
async def asyncio(
    trace_id: str,
    *,
    client: AuthenticatedClient,
) -> Optional[GetTraceResponse200]:
    """Get trace by ID (awaitable), returning only the parsed payload.

    Args:
        trace_id (str): Identifier of the trace to fetch.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceResponse200
    """
    detailed = await asyncio_detailed(trace_id=trace_id, client=client)
    return detailed.parsed
@@ -0,0 +1,233 @@
1
+ from http import HTTPStatus
2
+ from typing import Any, Optional, Union
3
+
4
+ import httpx
5
+
6
+ from ... import errors
7
+ from ...client import AuthenticatedClient, Client
8
+ from ...models.get_trace_ids_response_200 import GetTraceIdsResponse200
9
+ from ...types import UNSET, Response, Unset
10
+
11
+
12
def _get_kwargs(
    *,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> dict[str, Any]:
    """Build the httpx request arguments for listing trace IDs.

    Each argument maps to a camelCase query parameter; arguments that are
    ``UNSET`` or ``None`` are left out of the query.

    Returns:
        A dict with the ``method``, ``url`` and ``params`` keys.
    """
    candidates: dict[str, Any] = {
        "workloadId": workload_id,
        "workloadType": workload_type,
        "environment": environment,
        "limit": limit,
        "startTime": start_time,
        "endTime": end_time,
    }

    query: dict[str, Any] = {}
    for key, value in candidates.items():
        if value is UNSET or value is None:
            continue
        query[key] = value

    return {
        "method": "get",
        "url": "/traces",
        "params": query,
    }
44
+
45
+
46
def _parse_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[GetTraceIdsResponse200]:
    """Decode a 200 response into ``GetTraceIdsResponse200``.

    Any other status raises ``errors.UnexpectedStatus`` when
    ``client.raise_on_unexpected_status`` is truthy, and yields ``None``
    otherwise.
    """
    if response.status_code != 200:
        if client.raise_on_unexpected_status:
            raise errors.UnexpectedStatus(response.status_code, response.content)
        return None

    return GetTraceIdsResponse200.from_dict(response.json())
57
+
58
+
59
def _build_response(
    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[GetTraceIdsResponse200]:
    """Wrap an httpx response into a ``Response`` carrying status, body, headers and parsed value."""
    # Evaluated in the same order as the fields are listed below.
    status = HTTPStatus(response.status_code)
    body = response.content
    headers = response.headers
    parsed = _parse_response(client=client, response=response)
    return Response(
        status_code=status,
        content=body,
        headers=headers,
        parsed=parsed,
    )
68
+
69
+
70
def sync_detailed(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Response[GetTraceIdsResponse200]:
    """Get trace IDs (blocking), returning the full response wrapper.

    Args:
        workload_id (Union[Unset, str]): Sent as the ``workloadId`` query parameter.
        workload_type (Union[Unset, str]): Sent as the ``workloadType`` query parameter.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter.
        limit (Union[Unset, str]): Sent as the ``limit`` query parameter.
        start_time (Union[Unset, str]): Sent as the ``startTime`` query parameter.
        end_time (Union[Unset, str]): Sent as the ``endTime`` query parameter.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceIdsResponse200]
    """
    request_kwargs = _get_kwargs(
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    http_client = client.get_httpx_client()
    raw_response = http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw_response)
112
+
113
+
114
def sync(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Optional[GetTraceIdsResponse200]:
    """Get trace IDs (blocking), returning only the parsed payload.

    Args:
        workload_id (Union[Unset, str]): Sent as the ``workloadId`` query parameter.
        workload_type (Union[Unset, str]): Sent as the ``workloadType`` query parameter.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter.
        limit (Union[Unset, str]): Sent as the ``limit`` query parameter.
        start_time (Union[Unset, str]): Sent as the ``startTime`` query parameter.
        end_time (Union[Unset, str]): Sent as the ``endTime`` query parameter.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceIdsResponse200
    """
    detailed = sync_detailed(
        client=client,
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    return detailed.parsed
151
+
152
+
153
async def asyncio_detailed(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Response[GetTraceIdsResponse200]:
    """Get trace IDs (awaitable), returning the full response wrapper.

    Args:
        workload_id (Union[Unset, str]): Sent as the ``workloadId`` query parameter.
        workload_type (Union[Unset, str]): Sent as the ``workloadType`` query parameter.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter.
        limit (Union[Unset, str]): Sent as the ``limit`` query parameter.
        start_time (Union[Unset, str]): Sent as the ``startTime`` query parameter.
        end_time (Union[Unset, str]): Sent as the ``endTime`` query parameter.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        Response[GetTraceIdsResponse200]
    """
    request_kwargs = _get_kwargs(
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    http_client = client.get_async_httpx_client()
    raw_response = await http_client.request(**request_kwargs)
    return _build_response(client=client, response=raw_response)
193
+
194
+
195
async def asyncio(
    *,
    client: AuthenticatedClient,
    workload_id: Union[Unset, str] = UNSET,
    workload_type: Union[Unset, str] = UNSET,
    environment: Union[Unset, str] = UNSET,
    limit: Union[Unset, str] = UNSET,
    start_time: Union[Unset, str] = UNSET,
    end_time: Union[Unset, str] = UNSET,
) -> Optional[GetTraceIdsResponse200]:
    """Get trace IDs (awaitable), returning only the parsed payload.

    Args:
        workload_id (Union[Unset, str]): Sent as the ``workloadId`` query parameter.
        workload_type (Union[Unset, str]): Sent as the ``workloadType`` query parameter.
        environment (Union[Unset, str]): Sent as the ``environment`` query parameter.
        limit (Union[Unset, str]): Sent as the ``limit`` query parameter.
        start_time (Union[Unset, str]): Sent as the ``startTime`` query parameter.
        end_time (Union[Unset, str]): Sent as the ``endTime`` query parameter.

    Raises:
        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
        httpx.TimeoutException: If the request takes longer than Client.timeout.

    Returns:
        GetTraceIdsResponse200
    """
    detailed = await asyncio_detailed(
        client=client,
        workload_id=workload_id,
        workload_type=workload_type,
        environment=environment,
        limit=limit,
        start_time=start_time,
        end_time=end_time,
    )
    return detailed.parsed