latitude-sdk 4.0.0b1__tar.gz → 5.0.0b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/PKG-INFO +8 -8
  2. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/README.md +7 -7
  3. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/pyproject.toml +3 -1
  4. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/scripts/test.py +1 -1
  5. latitude_sdk-5.0.0b1/src/latitude_sdk/__init__.py +7 -0
  6. latitude_sdk-5.0.0b1/src/latitude_sdk/client/client.py +147 -0
  7. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/payloads.py +3 -3
  8. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/__init__.py +1 -0
  9. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/errors.py +1 -0
  10. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/evaluations.py +5 -14
  11. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/latitude.py +1 -0
  12. latitude_sdk-5.0.0b1/src/latitude_sdk/sdk/projects.py +35 -0
  13. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/prompts.py +47 -83
  14. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/types.py +46 -13
  15. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/util/utils.py +5 -0
  16. latitude_sdk-5.0.0b1/src/latitude_sdk/version/__init__.py +1 -0
  17. latitude_sdk-5.0.0b1/src/latitude_sdk/version/version.py +63 -0
  18. latitude_sdk-5.0.0b1/tests/projects/create_test.py +53 -0
  19. latitude_sdk-5.0.0b1/tests/projects/get_all_test.py +57 -0
  20. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/chat_test.py +128 -9
  21. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/run_test.py +187 -25
  22. latitude_sdk-5.0.0b1/tests/sdk/__init__.py +0 -0
  23. latitude_sdk-4.0.0b1/tests/test_acceptance.py → latitude_sdk-5.0.0b1/tests/sdk/acceptance_test.py +50 -39
  24. latitude_sdk-5.0.0b1/tests/sdk/client_test.py +110 -0
  25. latitude_sdk-5.0.0b1/tests/sdk/memory_test.py +376 -0
  26. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/utils/fixtures.py +43 -287
  27. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/utils/utils.py +9 -2
  28. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/uv.lock +18 -1
  29. latitude_sdk-4.0.0b1/src/latitude_sdk/__init__.py +0 -1
  30. latitude_sdk-4.0.0b1/src/latitude_sdk/client/client.py +0 -276
  31. latitude_sdk-4.0.0b1/src/latitude_sdk/sdk/projects.py +0 -41
  32. latitude_sdk-4.0.0b1/tests/projects/create_test.py +0 -101
  33. latitude_sdk-4.0.0b1/tests/projects/get_all_test.py +0 -137
  34. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/.gitignore +0 -0
  35. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/.python-version +0 -0
  36. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/LICENSE.md +0 -0
  37. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/scripts/format.py +0 -0
  38. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/scripts/lint.py +0 -0
  39. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/__init__.py +0 -0
  40. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/router.py +0 -0
  41. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/env/__init__.py +0 -0
  42. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/env/env.py +0 -0
  43. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/py.typed +0 -0
  44. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/logs.py +0 -0
  45. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/util/__init__.py +0 -0
  46. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/__init__.py +0 -0
  47. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/evaluations/__init__.py +0 -0
  48. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/evaluations/annotate_test.py +0 -0
  49. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/logs/__init__.py +0 -0
  50. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/logs/create_test.py +0 -0
  51. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/projects/__init__.py +0 -0
  52. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/__init__.py +0 -0
  53. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/get_all_test.py +0 -0
  54. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/get_or_create_test.py +0 -0
  55. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/get_test.py +0 -0
  56. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/render_chain_test.py +0 -0
  57. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/render_test.py +0 -0
  58. {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: latitude-sdk
3
- Version: 4.0.0b1
3
+ Version: 5.0.0b1
4
4
  Summary: Latitude SDK for Python
5
5
  Project-URL: repository, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python
6
6
  Project-URL: homepage, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python#readme
@@ -62,7 +62,13 @@ Requires uv `0.5.10` or higher.
62
62
 
63
63
  ### Running only a specific test
64
64
 
65
- Mark the test with an `only` marker:
65
+ Specify the test inline:
66
+
67
+ ```python
68
+ uv run scripts/test.py <test_path>::<test_case>::<test_name>
69
+ ```
70
+
71
+ Or mark the test with an `only` marker:
66
72
 
67
73
  ```python
68
74
  import pytest
@@ -78,12 +84,6 @@ async def my_test(self):
78
84
  uv run scripts/test.py -m only
79
85
  ```
80
86
 
81
- Another way is to specify the test in line:
82
-
83
- ```python
84
- uv run scripts/test.py <test_path>::<test_case>::<test_name>
85
- ```
86
-
87
87
  ## License
88
88
 
89
89
  The SDK is licensed under the [MIT License](https://opensource.org/licenses/MIT) - read the [LICENSE](/LICENSE) file for details.
@@ -43,7 +43,13 @@ Requires uv `0.5.10` or higher.
43
43
 
44
44
  ### Running only a specific test
45
45
 
46
- Mark the test with an `only` marker:
46
+ Specify the test inline:
47
+
48
+ ```python
49
+ uv run scripts/test.py <test_path>::<test_case>::<test_name>
50
+ ```
51
+
52
+ Or mark the test with an `only` marker:
47
53
 
48
54
  ```python
49
55
  import pytest
@@ -59,12 +65,6 @@ async def my_test(self):
59
65
  uv run scripts/test.py -m only
60
66
  ```
61
67
 
62
- Another way is to specify the test in line:
63
-
64
- ```python
65
- uv run scripts/test.py <test_path>::<test_case>::<test_name>
66
- ```
67
-
68
68
  ## License
69
69
 
70
70
  The SDK is licensed under the [MIT License](https://opensource.org/licenses/MIT) - read the [LICENSE](/LICENSE) file for details.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "latitude-sdk"
3
- version = "4.0.0-beta.1"
3
+ version = "5.0.0-beta.1"
4
4
  description = "Latitude SDK for Python"
5
5
  authors = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
6
6
  maintainers = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
@@ -27,6 +27,7 @@ dev = [
27
27
  "pyright>=1.1.401",
28
28
  "ruff>=0.8.3",
29
29
  "sh>=1.14.3",
30
+ "psutil>=7.0.0",
30
31
  ]
31
32
 
32
33
  [tool.pyright]
@@ -35,6 +36,7 @@ typeCheckingMode = "strict"
35
36
  reportMissingTypeStubs = false
36
37
  reportUnnecessaryIsInstance = false
37
38
  reportPrivateUsage = false
39
+ reportPrivateImportUsage = false
38
40
 
39
41
  [tool.ruff]
40
42
  target-version = "py39"
@@ -4,4 +4,4 @@ from sh import pytest # type: ignore
4
4
 
5
5
  files = sys.argv[1:] or ["."]
6
6
 
7
- pytest(*files, _out=sys.stdout)
7
+ pytest("-rs", *files, _out=sys.stdout)
@@ -0,0 +1,7 @@
1
+ from .version import version
2
+
3
+ __version__ = version.pep440
4
+ __version_info__ = version.info
5
+ __version_semver__ = version.semver
6
+
7
+ from .sdk import *
@@ -0,0 +1,147 @@
1
+ import asyncio
2
+ import json
3
+ from contextlib import asynccontextmanager
4
+ from typing import Any, AsyncGenerator, Optional
5
+
6
+ import httpx
7
+ import httpx_sse
8
+
9
+ from latitude_sdk.client.payloads import ErrorResponse, RequestBody, RequestHandler, RequestParams
10
+ from latitude_sdk.client.router import Router, RouterOptions
11
+ from latitude_sdk.sdk.errors import ApiError, ApiErrorCodes, ApiErrorDbRef
12
+ from latitude_sdk.sdk.types import LogSources
13
+ from latitude_sdk.util import Model
14
+ from latitude_sdk.version import version
15
+
16
+ RETRIABLE_STATUSES = [408, 429, 500, 502, 503, 504]
17
+
18
+ ClientEvent = httpx_sse.ServerSentEvent
19
+
20
+
21
+ class ClientResponse(httpx.Response):
22
+ async def sse(self: httpx.Response) -> AsyncGenerator[ClientEvent, Any]:
23
+ source = httpx_sse.EventSource(self)
24
+
25
+ async for event in source.aiter_sse():
26
+ yield event
27
+
28
+
29
+ httpx.Response.sse = ClientResponse.sse # pyright: ignore [reportAttributeAccessIssue]
30
+
31
+
32
+ class ClientOptions(Model):
33
+ api_key: str
34
+ retries: int
35
+ delay: float
36
+ timeout: Optional[float]
37
+ source: LogSources
38
+ router: RouterOptions
39
+
40
+
41
+ class Client:
42
+ options: ClientOptions
43
+ router: Router
44
+
45
+ def __init__(self, options: ClientOptions):
46
+ self.options = options
47
+ self.router = Router(options.router)
48
+
49
+ @asynccontextmanager
50
+ async def request(
51
+ self,
52
+ handler: RequestHandler,
53
+ params: Optional[RequestParams] = None,
54
+ body: Optional[RequestBody] = None,
55
+ stream: Optional[bool] = None,
56
+ ) -> AsyncGenerator[ClientResponse, Any]:
57
+ client = httpx.AsyncClient(
58
+ headers={
59
+ "Authorization": f"Bearer {self.options.api_key}",
60
+ "X-Latitude-SDK-Version": version.semver,
61
+ "Content-Type": "application/json",
62
+ "Accept": "text/event-stream" if stream else "application/json",
63
+ },
64
+ timeout=self.options.timeout,
65
+ follow_redirects=False,
66
+ max_redirects=0,
67
+ )
68
+ response = None
69
+ attempt = 1
70
+
71
+ try:
72
+ method, url = self.router.resolve(handler, params)
73
+ content = None
74
+ if body:
75
+ content = json.dumps(
76
+ {
77
+ **json.loads(body.model_dump_json()),
78
+ "__internal": {"source": self.options.source},
79
+ }
80
+ )
81
+
82
+ while attempt <= self.options.retries:
83
+ try:
84
+ request = client.build_request(method=method, url=url, content=content)
85
+ response = await client.send(request=request, stream=stream or False)
86
+ response.raise_for_status()
87
+
88
+ yield response # pyright: ignore [reportReturnType]
89
+ break
90
+
91
+ except Exception as exception:
92
+ if isinstance(exception, ApiError):
93
+ raise exception
94
+
95
+ if attempt >= self.options.retries:
96
+ raise await self._exception(exception, response) from exception
97
+
98
+ if response and response.status_code in RETRIABLE_STATUSES:
99
+ await asyncio.sleep(self.options.delay * (2 ** (attempt - 1)))
100
+ else:
101
+ raise await self._exception(exception, response) from exception
102
+
103
+ finally:
104
+ if response:
105
+ await response.aclose()
106
+
107
+ attempt += 1
108
+
109
+ except Exception as exception:
110
+ if isinstance(exception, ApiError):
111
+ raise exception
112
+
113
+ raise await self._exception(exception, response) from exception
114
+
115
+ finally:
116
+ await client.aclose()
117
+
118
+ async def _exception(self, exception: Exception, response: Optional[httpx.Response] = None) -> ApiError:
119
+ if not response:
120
+ return ApiError(
121
+ status=500,
122
+ code=ApiErrorCodes.InternalServerError,
123
+ message=str(exception),
124
+ response=str(exception),
125
+ )
126
+
127
+ try:
128
+ if not response.is_stream_consumed:
129
+ await response.aread()
130
+
131
+ error = ErrorResponse.model_validate_json(response.content)
132
+
133
+ return ApiError(
134
+ status=response.status_code,
135
+ code=error.code,
136
+ message=error.message,
137
+ response=response.text,
138
+ db_ref=ApiErrorDbRef(**dict(error.db_ref)) if error.db_ref else None,
139
+ )
140
+
141
+ except Exception:
142
+ return ApiError(
143
+ status=response.status_code,
144
+ code=ApiErrorCodes.InternalServerError,
145
+ message=str(exception),
146
+ response=response.text,
147
+ )
@@ -44,8 +44,8 @@ class RunPromptRequestBody(Model):
44
44
  path: str
45
45
  parameters: Optional[Dict[str, Any]] = None
46
46
  custom_identifier: Optional[str] = Field(default=None, alias=str("customIdentifier"))
47
- stream: Optional[bool] = None
48
47
  tools: Optional[List[str]] = None
48
+ stream: Optional[bool] = None
49
49
 
50
50
 
51
51
  class ChatPromptRequestParams(Model):
@@ -54,6 +54,7 @@ class ChatPromptRequestParams(Model):
54
54
 
55
55
  class ChatPromptRequestBody(Model):
56
56
  messages: List[Message]
57
+ tools: Optional[List[str]] = None
57
58
  stream: Optional[bool] = None
58
59
 
59
60
 
@@ -81,11 +82,10 @@ class AnnotateEvaluationRequestParams(EvaluationRequestParams, Model):
81
82
 
82
83
 
83
84
  class AnnotateEvaluationRequestBody(Model):
84
- score: int
85
-
86
85
  class Metadata(Model):
87
86
  reason: str
88
87
 
88
+ score: int
89
89
  metadata: Optional[Metadata] = None
90
90
 
91
91
 
@@ -2,5 +2,6 @@ from .errors import *
2
2
  from .evaluations import *
3
3
  from .latitude import *
4
4
  from .logs import *
5
+ from .projects import *
5
6
  from .prompts import *
6
7
  from .types import *
@@ -7,6 +7,7 @@ from latitude_sdk.util import Model, StrEnum
7
7
  class ApiErrorCodes(StrEnum):
8
8
  # LatitudeErrorCodes
9
9
  NotFoundError = "NotFoundError"
10
+ BadRequestError = "BadRequestError"
10
11
 
11
12
  # RunErrorCodes
12
13
  AIRunError = "ai_run_error"
@@ -1,5 +1,4 @@
1
- from datetime import datetime
2
- from typing import Any, Optional, Union
1
+ from typing import Optional
3
2
 
4
3
  from latitude_sdk.client import (
5
4
  AnnotateEvaluationRequestBody,
@@ -7,24 +6,16 @@ from latitude_sdk.client import (
7
6
  Client,
8
7
  RequestHandler,
9
8
  )
10
- from latitude_sdk.sdk.types import SdkOptions
11
- from latitude_sdk.util import Field, Model
9
+ from latitude_sdk.sdk.types import EvaluationResult, SdkOptions
10
+ from latitude_sdk.util import Model
12
11
 
13
12
 
14
13
  class AnnotateEvaluationOptions(Model):
15
14
  reason: str
16
15
 
17
16
 
18
- class AnnotateEvaluationResult(Model):
19
- uuid: str
20
- score: int
21
- normalized_score: int = Field(alias=str("normalizedScore"))
22
- metadata: dict[str, Any]
23
- has_passed: bool = Field(alias=str("hasPassed"))
24
- created_at: datetime = Field(alias=str("createdAt"))
25
- updated_at: datetime = Field(alias=str("updatedAt"))
26
- version_uuid: str = Field(alias=str("versionUuid"))
27
- error: Optional[Union[str, None]] = None
17
+ class AnnotateEvaluationResult(EvaluationResult, Model):
18
+ pass
28
19
 
29
20
 
30
21
  class Evaluations:
@@ -78,6 +78,7 @@ class Latitude:
78
78
  )
79
79
 
80
80
  self.promptl = Promptl(self._options.promptl)
81
+
81
82
  self.projects = Projects(self._client, self._options)
82
83
  self.prompts = Prompts(self._client, self.promptl, self._options)
83
84
  self.logs = Logs(self._client, self._options)
@@ -0,0 +1,35 @@
1
+ from typing import List
2
+
3
+ from latitude_sdk.client import Client, CreateProjectRequestBody, RequestHandler
4
+ from latitude_sdk.sdk.types import Project, SdkOptions, Version
5
+ from latitude_sdk.util import Adapter as AdapterUtil
6
+ from latitude_sdk.util import Model
7
+
8
+ _GetAllProjectResults = AdapterUtil[List[Project]](List[Project])
9
+
10
+
11
+ class CreateProjectResult(Model):
12
+ project: Project
13
+ version: Version
14
+
15
+
16
+ class Projects:
17
+ _options: SdkOptions
18
+ _client: Client
19
+
20
+ def __init__(self, client: Client, options: SdkOptions):
21
+ self._options = options
22
+ self._client = client
23
+
24
+ async def get_all(self) -> List[Project]:
25
+ async with self._client.request(
26
+ handler=RequestHandler.GetAllProjects,
27
+ ) as response:
28
+ return _GetAllProjectResults.validate_json(response.content)
29
+
30
+ async def create(self, name: str) -> CreateProjectResult:
31
+ async with self._client.request(
32
+ handler=RequestHandler.CreateProject,
33
+ body=CreateProjectRequestBody(name=name),
34
+ ) as response:
35
+ return CreateProjectResult.model_validate_json(response.content)
@@ -1,11 +1,6 @@
1
- from typing import Any, AsyncGenerator, Callable, List, Optional, Sequence
1
+ from typing import Any, AsyncGenerator, List, Optional, Sequence
2
2
 
3
- from promptl_ai import (
4
- Adapter,
5
- Message,
6
- MessageLike,
7
- Promptl,
8
- )
3
+ from promptl_ai import Adapter, Message, MessageLike, Promptl
9
4
  from promptl_ai.bindings.types import _Message
10
5
 
11
6
  from latitude_sdk.client import (
@@ -30,6 +25,8 @@ from latitude_sdk.sdk.types import (
30
25
  OnToolCall,
31
26
  OnToolCallDetails,
32
27
  Prompt,
28
+ ProviderEvents,
29
+ ProviderEventToolCalled,
33
30
  Providers,
34
31
  SdkOptions,
35
32
  StreamCallbacks,
@@ -145,7 +142,7 @@ class Prompts:
145
142
  self,
146
143
  stream: AsyncGenerator[ClientEvent, Any],
147
144
  on_event: Optional[StreamCallbacks.OnEvent],
148
- on_tool_call: Optional[Callable[[dict[str, Any]], Any]] = None,
145
+ tools: Optional[dict[str, OnToolCall]],
149
146
  ) -> FinishedResult:
150
147
  uuid = None
151
148
  conversation: List[Message] = []
@@ -153,6 +150,7 @@ class Prompts:
153
150
 
154
151
  async for stream_event in stream:
155
152
  event = None
153
+ tool_call = None
156
154
 
157
155
  if stream_event.event == str(StreamEvents.Latitude):
158
156
  event = _LatitudeEvent.validate_json(stream_event.data)
@@ -174,9 +172,8 @@ class Prompts:
174
172
  event = stream_event.json()
175
173
  event["event"] = StreamEvents.Provider
176
174
 
177
- # Handle tool calls when received in the stream
178
- if on_tool_call and event.get("type") == "tool-call":
179
- await on_tool_call(event)
175
+ if event.get("type") == str(ProviderEvents.ToolCalled):
176
+ tool_call = ProviderEventToolCalled.model_validate_json(stream_event.data)
180
177
 
181
178
  else:
182
179
  raise ApiError(
@@ -189,6 +186,9 @@ class Prompts:
189
186
  if on_event:
190
187
  on_event(event)
191
188
 
189
+ if tool_call:
190
+ await self._handle_tool_call(tool_call, tools)
191
+
192
192
  if not uuid or not response:
193
193
  raise ApiError(
194
194
  status=500,
@@ -197,11 +197,7 @@ class Prompts:
197
197
  response="Stream ended without a chain-complete event. Missing uuid or response.",
198
198
  )
199
199
 
200
- return FinishedResult(
201
- uuid=uuid,
202
- conversation=conversation,
203
- response=response,
204
- )
200
+ return FinishedResult(uuid=uuid, conversation=conversation, response=response)
205
201
 
206
202
  @staticmethod
207
203
  async def _wrap_tool_handler(
@@ -212,66 +208,45 @@ class Prompts:
212
208
  try:
213
209
  result = await handler(arguments, details)
214
210
 
215
- return ToolResult(**tool_result, result=result)
211
+ return ToolResult(**tool_result, result=result, is_error=False)
216
212
  except Exception as exception:
217
213
  return ToolResult(**tool_result, result=str(exception), is_error=True)
218
214
 
219
215
  async def _handle_tool_call(
220
- self,
221
- event: dict[str, Any],
222
- tools: dict[str, OnToolCall],
216
+ self, tool_call: ProviderEventToolCalled, tools: Optional[dict[str, OnToolCall]]
223
217
  ) -> None:
224
- toolCallId: str = event["toolCallId"]
225
- toolName: str = event["toolName"]
226
- args: dict[str, Any] = event["args"]
227
-
228
- tool = tools.get(toolName)
229
- if not tool:
218
+ # NOTE: Do not handle tool calls if user specified no tools
219
+ if not tools:
230
220
  return
231
221
 
222
+ tool_handler = tools.get(tool_call.name)
223
+ if not tool_handler:
224
+ raise ApiError(
225
+ status=400,
226
+ code=ApiErrorCodes.AIRunError,
227
+ message=f"Tool {tool_call.name} not supplied",
228
+ response=f"Tool {tool_call.name} not supplied",
229
+ )
230
+
232
231
  tool_result = await self._wrap_tool_handler(
233
- tool,
234
- args,
232
+ tool_handler,
233
+ tool_call.arguments,
235
234
  OnToolCallDetails(
236
- id=toolCallId,
237
- name=toolName,
238
- arguments=args,
235
+ id=tool_call.id,
236
+ name=tool_call.name,
237
+ arguments=tool_call.arguments,
239
238
  ),
240
239
  )
241
240
 
242
- try:
243
- async with self._client.request(
244
- handler=RequestHandler.ToolResults,
245
- params=None,
246
- body=ToolResultsRequestBody(
247
- tool_call_id=toolCallId,
248
- result=tool_result.result,
249
- is_error=tool_result.is_error,
250
- ),
251
- ) as _:
252
- pass
253
- except Exception as exception:
254
- if not isinstance(exception, ApiError):
255
- exception = ApiError(
256
- status=500,
257
- code=ApiErrorCodes.InternalServerError,
258
- message=str(exception),
259
- response=str(exception),
260
- )
261
-
262
- # Add context about which tool failed
263
- message = f"Failed to execute tool {toolName}. \nLatitude API returned the following error:\
264
- \n\n{exception.message}"
265
-
266
- raise ApiError(
267
- status=exception.status,
268
- code=exception.code,
269
- message=message,
270
- response=exception.response,
271
- ) from exception
272
-
273
- def _on_tool_call(self, tools: dict[str, OnToolCall]) -> Callable[[dict[str, Any]], Any]:
274
- return lambda event: self._handle_tool_call(event, tools)
241
+ async with self._client.request(
242
+ handler=RequestHandler.ToolResults,
243
+ body=ToolResultsRequestBody(
244
+ tool_call_id=tool_call.id,
245
+ result=tool_result.result,
246
+ is_error=tool_result.is_error,
247
+ ),
248
+ ):
249
+ pass
275
250
 
276
251
  async def get(self, path: str, options: Optional[GetPromptOptions] = None) -> GetPromptResult:
277
252
  options = GetPromptOptions(**{**dict(self._options), **dict(options or {})})
@@ -338,16 +313,13 @@ class Prompts:
338
313
  path=path,
339
314
  parameters=options.parameters,
340
315
  custom_identifier=options.custom_identifier,
341
- stream=options.stream,
342
316
  tools=list(options.tools.keys()) if options.tools and options.stream else None,
317
+ stream=options.stream,
343
318
  ),
319
+ stream=options.stream,
344
320
  ) as response:
345
321
  if options.stream:
346
- result = await self._handle_stream(
347
- response.sse(),
348
- options.on_event,
349
- self._on_tool_call(options.tools if options.tools else {}),
350
- )
322
+ result = await self._handle_stream(response.sse(), options.on_event, options.tools)
351
323
  else:
352
324
  result = RunPromptResult.model_validate_json(response.content)
353
325
 
@@ -373,10 +345,7 @@ class Prompts:
373
345
  return None
374
346
 
375
347
  async def chat(
376
- self,
377
- uuid: str,
378
- messages: Sequence[MessageLike],
379
- options: Optional[ChatPromptOptions] = None,
348
+ self, uuid: str, messages: Sequence[MessageLike], options: Optional[ChatPromptOptions] = None
380
349
  ) -> Optional[ChatPromptResult]:
381
350
  options = ChatPromptOptions(**{**dict(self._options), **dict(options or {})})
382
351
 
@@ -390,15 +359,13 @@ class Prompts:
390
359
  ),
391
360
  body=ChatPromptRequestBody(
392
361
  messages=messages,
362
+ tools=list(options.tools.keys()) if options.tools and options.stream else None,
393
363
  stream=options.stream,
394
364
  ),
365
+ stream=options.stream,
395
366
  ) as response:
396
367
  if options.stream:
397
- result = await self._handle_stream(
398
- response.sse(),
399
- options.on_event,
400
- self._on_tool_call(options.tools if options.tools else {}),
401
- )
368
+ result = await self._handle_stream(response.sse(), options.on_event, options.tools)
402
369
  else:
403
370
  result = ChatPromptResult.model_validate_json(response.content)
404
371
 
@@ -451,10 +418,7 @@ class Prompts:
451
418
  )
452
419
 
453
420
  async def render_chain(
454
- self,
455
- prompt: Prompt,
456
- on_step: OnStep,
457
- options: Optional[RenderChainOptions] = None,
421
+ self, prompt: Prompt, on_step: OnStep, options: Optional[RenderChainOptions] = None
458
422
  ) -> RenderChainResult:
459
423
  options = RenderChainOptions(**{**dict(self._options), **dict(options or {})})
460
424
  adapter = options.adapter or _PROVIDER_TO_ADAPTER.get(prompt.provider or Providers.OpenAI, Adapter.OpenAI)