latitude-sdk 3.0.2__tar.gz → 5.0.0b1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/PKG-INFO +13 -11
  2. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/README.md +12 -10
  3. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/pyproject.toml +3 -1
  4. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/scripts/test.py +1 -1
  5. latitude_sdk-5.0.0b1/src/latitude_sdk/__init__.py +7 -0
  6. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/client.py +15 -8
  7. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/payloads.py +18 -2
  8. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/router.py +10 -1
  9. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/__init__.py +1 -0
  10. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/errors.py +1 -0
  11. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/evaluations.py +5 -14
  12. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/latitude.py +4 -0
  13. latitude_sdk-5.0.0b1/src/latitude_sdk/sdk/projects.py +35 -0
  14. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/prompts.py +52 -115
  15. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/types.py +62 -27
  16. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/util/utils.py +5 -0
  17. latitude_sdk-5.0.0b1/src/latitude_sdk/version/__init__.py +1 -0
  18. latitude_sdk-5.0.0b1/src/latitude_sdk/version/version.py +63 -0
  19. latitude_sdk-5.0.0b1/tests/projects/create_test.py +53 -0
  20. latitude_sdk-5.0.0b1/tests/projects/get_all_test.py +57 -0
  21. latitude_sdk-5.0.0b1/tests/prompts/__init__.py +0 -0
  22. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/prompts/chat_test.py +37 -193
  23. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/prompts/run_test.py +29 -199
  24. latitude_sdk-5.0.0b1/tests/sdk/__init__.py +0 -0
  25. latitude_sdk-5.0.0b1/tests/sdk/acceptance_test.py +325 -0
  26. latitude_sdk-5.0.0b1/tests/sdk/client_test.py +110 -0
  27. latitude_sdk-5.0.0b1/tests/sdk/memory_test.py +376 -0
  28. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/utils/fixtures.py +99 -470
  29. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/utils/utils.py +9 -2
  30. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/uv.lock +195 -177
  31. latitude_sdk-3.0.2/src/latitude_sdk/__init__.py +0 -1
  32. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/.gitignore +0 -0
  33. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/.python-version +0 -0
  34. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/LICENSE.md +0 -0
  35. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/scripts/format.py +0 -0
  36. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/scripts/lint.py +0 -0
  37. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/__init__.py +0 -0
  38. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/env/__init__.py +0 -0
  39. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/env/env.py +0 -0
  40. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/py.typed +0 -0
  41. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/logs.py +0 -0
  42. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/src/latitude_sdk/util/__init__.py +0 -0
  43. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/__init__.py +0 -0
  44. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/evaluations/__init__.py +0 -0
  45. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/evaluations/annotate_test.py +0 -0
  46. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/logs/__init__.py +0 -0
  47. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/logs/create_test.py +0 -0
  48. {latitude_sdk-3.0.2/tests/prompts → latitude_sdk-5.0.0b1/tests/projects}/__init__.py +0 -0
  49. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/prompts/get_all_test.py +0 -0
  50. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/prompts/get_or_create_test.py +0 -0
  51. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/prompts/get_test.py +0 -0
  52. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/prompts/render_chain_test.py +0 -0
  53. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/prompts/render_test.py +0 -0
  54. {latitude_sdk-3.0.2 → latitude_sdk-5.0.0b1}/tests/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: latitude-sdk
3
- Version: 3.0.2
3
+ Version: 5.0.0b1
4
4
  Summary: Latitude SDK for Python
5
5
  Project-URL: repository, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python
6
6
  Project-URL: homepage, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python#readme
@@ -46,7 +46,7 @@ await sdk.prompts.run("joke-teller", RunPromptOptions(
46
46
  ))
47
47
  ```
48
48
 
49
- Find more [examples](https://github.com/latitude-dev/latitude-llm/tree/main/examples/sdks/python).
49
+ Find more [examples](https://docs.latitude.so/examples/sdk).
50
50
 
51
51
  ## Development
52
52
 
@@ -60,7 +60,15 @@ Requires uv `0.5.10` or higher.
60
60
  - Build package: `uv build`
61
61
  - Publish package: `uv publish`
62
62
 
63
- ## Run only one test
63
+ ### Running only a specific test
64
+
65
+ Specify the test inline:
66
+
67
+ ```python
68
+ uv run scripts/test.py <test_path>::<test_case>::<test_name>
69
+ ```
70
+
71
+ Or mark the test with an `only` marker:
64
72
 
65
73
  ```python
66
74
  import pytest
@@ -70,18 +78,12 @@ async def my_test(self):
70
78
  # ... your code
71
79
  ```
72
80
 
73
- And then run the tests with the marker `only`:
81
+ ...and then run the tests with the marker `only`:
74
82
 
75
83
  ```sh
76
84
  uv run scripts/test.py -m only
77
85
  ```
78
86
 
79
- Other way is all in line:
80
-
81
- ```python
82
- uv run scripts/test.py <test_path>::<test_case>::<test_name>
83
- ```
84
-
85
87
  ## License
86
88
 
87
- The SDK is licensed under the [LGPL-3.0 License](https://opensource.org/licenses/LGPL-3.0) - read the [LICENSE](/LICENSE) file for details.
89
+ The SDK is licensed under the [MIT License](https://opensource.org/licenses/MIT) - read the [LICENSE](/LICENSE) file for details.
@@ -27,7 +27,7 @@ await sdk.prompts.run("joke-teller", RunPromptOptions(
27
27
  ))
28
28
  ```
29
29
 
30
- Find more [examples](https://github.com/latitude-dev/latitude-llm/tree/main/examples/sdks/python).
30
+ Find more [examples](https://docs.latitude.so/examples/sdk).
31
31
 
32
32
  ## Development
33
33
 
@@ -41,7 +41,15 @@ Requires uv `0.5.10` or higher.
41
41
  - Build package: `uv build`
42
42
  - Publish package: `uv publish`
43
43
 
44
- ## Run only one test
44
+ ### Running only a specific test
45
+
46
+ Specify the test inline:
47
+
48
+ ```python
49
+ uv run scripts/test.py <test_path>::<test_case>::<test_name>
50
+ ```
51
+
52
+ Or mark the test with an `only` marker:
45
53
 
46
54
  ```python
47
55
  import pytest
@@ -51,18 +59,12 @@ async def my_test(self):
51
59
  # ... your code
52
60
  ```
53
61
 
54
- And then run the tests with the marker `only`:
62
+ ...and then run the tests with the marker `only`:
55
63
 
56
64
  ```sh
57
65
  uv run scripts/test.py -m only
58
66
  ```
59
67
 
60
- Other way is all in line:
61
-
62
- ```python
63
- uv run scripts/test.py <test_path>::<test_case>::<test_name>
64
- ```
65
-
66
68
  ## License
67
69
 
68
- The SDK is licensed under the [LGPL-3.0 License](https://opensource.org/licenses/LGPL-3.0) - read the [LICENSE](/LICENSE) file for details.
70
+ The SDK is licensed under the [MIT License](https://opensource.org/licenses/MIT) - read the [LICENSE](/LICENSE) file for details.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "latitude-sdk"
3
- version = "3.0.2"
3
+ version = "5.0.0-beta.1"
4
4
  description = "Latitude SDK for Python"
5
5
  authors = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
6
6
  maintainers = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
@@ -27,6 +27,7 @@ dev = [
27
27
  "pyright>=1.1.401",
28
28
  "ruff>=0.8.3",
29
29
  "sh>=1.14.3",
30
+ "psutil>=7.0.0",
30
31
  ]
31
32
 
32
33
  [tool.pyright]
@@ -35,6 +36,7 @@ typeCheckingMode = "strict"
35
36
  reportMissingTypeStubs = false
36
37
  reportUnnecessaryIsInstance = false
37
38
  reportPrivateUsage = false
39
+ reportPrivateImportUsage = false
38
40
 
39
41
  [tool.ruff]
40
42
  target-version = "py39"
@@ -4,4 +4,4 @@ from sh import pytest # type: ignore
4
4
 
5
5
  files = sys.argv[1:] or ["."]
6
6
 
7
- pytest(*files, _out=sys.stdout)
7
+ pytest("-rs", *files, _out=sys.stdout)
@@ -0,0 +1,7 @@
1
+ from .version import version
2
+
3
+ __version__ = version.pep440
4
+ __version_info__ = version.info
5
+ __version_semver__ = version.semver
6
+
7
+ from .sdk import *
@@ -8,15 +8,12 @@ import httpx_sse
8
8
 
9
9
  from latitude_sdk.client.payloads import ErrorResponse, RequestBody, RequestHandler, RequestParams
10
10
  from latitude_sdk.client.router import Router, RouterOptions
11
- from latitude_sdk.sdk.errors import (
12
- ApiError,
13
- ApiErrorCodes,
14
- ApiErrorDbRef,
15
- )
11
+ from latitude_sdk.sdk.errors import ApiError, ApiErrorCodes, ApiErrorDbRef
16
12
  from latitude_sdk.sdk.types import LogSources
17
13
  from latitude_sdk.util import Model
14
+ from latitude_sdk.version import version
18
15
 
19
- RETRIABLE_STATUSES = [408, 409, 429, 500, 502, 503, 504]
16
+ RETRIABLE_STATUSES = [408, 429, 500, 502, 503, 504]
20
17
 
21
18
  ClientEvent = httpx_sse.ServerSentEvent
22
19
 
@@ -51,12 +48,18 @@ class Client:
51
48
 
52
49
  @asynccontextmanager
53
50
  async def request(
54
- self, handler: RequestHandler, params: RequestParams, body: Optional[RequestBody] = None
51
+ self,
52
+ handler: RequestHandler,
53
+ params: Optional[RequestParams] = None,
54
+ body: Optional[RequestBody] = None,
55
+ stream: Optional[bool] = None,
55
56
  ) -> AsyncGenerator[ClientResponse, Any]:
56
57
  client = httpx.AsyncClient(
57
58
  headers={
58
59
  "Authorization": f"Bearer {self.options.api_key}",
60
+ "X-Latitude-SDK-Version": version.semver,
59
61
  "Content-Type": "application/json",
62
+ "Accept": "text/event-stream" if stream else "application/json",
60
63
  },
61
64
  timeout=self.options.timeout,
62
65
  follow_redirects=False,
@@ -78,7 +81,8 @@ class Client:
78
81
 
79
82
  while attempt <= self.options.retries:
80
83
  try:
81
- response = await client.request(method=method, url=url, content=content)
84
+ request = client.build_request(method=method, url=url, content=content)
85
+ response = await client.send(request=request, stream=stream or False)
82
86
  response.raise_for_status()
83
87
 
84
88
  yield response # pyright: ignore [reportReturnType]
@@ -121,6 +125,9 @@ class Client:
121
125
  )
122
126
 
123
127
  try:
128
+ if not response.is_stream_consumed:
129
+ await response.aread()
130
+
124
131
  error = ErrorResponse.model_validate_json(response.content)
125
132
 
126
133
  return ApiError(
@@ -44,6 +44,7 @@ class RunPromptRequestBody(Model):
44
44
  path: str
45
45
  parameters: Optional[Dict[str, Any]] = None
46
46
  custom_identifier: Optional[str] = Field(default=None, alias=str("customIdentifier"))
47
+ tools: Optional[List[str]] = None
47
48
  stream: Optional[bool] = None
48
49
 
49
50
 
@@ -53,6 +54,7 @@ class ChatPromptRequestParams(Model):
53
54
 
54
55
  class ChatPromptRequestBody(Model):
55
56
  messages: List[Message]
57
+ tools: Optional[List[str]] = None
56
58
  stream: Optional[bool] = None
57
59
 
58
60
 
@@ -80,14 +82,23 @@ class AnnotateEvaluationRequestParams(EvaluationRequestParams, Model):
80
82
 
81
83
 
82
84
  class AnnotateEvaluationRequestBody(Model):
83
- score: int
84
-
85
85
  class Metadata(Model):
86
86
  reason: str
87
87
 
88
+ score: int
88
89
  metadata: Optional[Metadata] = None
89
90
 
90
91
 
92
+ class ToolResultsRequestBody(Model):
93
+ tool_call_id: str = Field(alias=str("toolCallId"))
94
+ result: Any
95
+ is_error: Optional[bool] = Field(default=None, alias=str("isError"))
96
+
97
+
98
+ class CreateProjectRequestBody(Model):
99
+ name: str
100
+
101
+
91
102
  RequestParams = Union[
92
103
  GetPromptRequestParams,
93
104
  GetAllPromptRequestParams,
@@ -105,6 +116,8 @@ RequestBody = Union[
105
116
  ChatPromptRequestBody,
106
117
  CreateLogRequestBody,
107
118
  AnnotateEvaluationRequestBody,
119
+ ToolResultsRequestBody,
120
+ CreateProjectRequestBody,
108
121
  ]
109
122
 
110
123
 
@@ -116,3 +129,6 @@ class RequestHandler(StrEnum):
116
129
  ChatPrompt = "CHAT_PROMPT"
117
130
  CreateLog = "CREATE_LOG"
118
131
  AnnotateEvaluation = "ANNOTATE_EVALUATION"
132
+ ToolResults = "TOOL_RESULTS"
133
+ GetAllProjects = "GET_ALL_PROJECTS"
134
+ CreateProject = "CREATE_PROJECT"
@@ -27,7 +27,7 @@ class Router:
27
27
  def __init__(self, options: RouterOptions):
28
28
  self.options = options
29
29
 
30
- def resolve(self, handler: RequestHandler, params: RequestParams) -> Tuple[str, str]:
30
+ def resolve(self, handler: RequestHandler, params: Optional[RequestParams] = None) -> Tuple[str, str]:
31
31
  if handler == RequestHandler.GetPrompt:
32
32
  assert isinstance(params, GetPromptRequestParams)
33
33
 
@@ -90,6 +90,15 @@ class Router:
90
90
 
91
91
  return "POST", self.conversations().annotate(params.conversation_uuid, params.evaluation_uuid)
92
92
 
93
+ elif handler == RequestHandler.ToolResults:
94
+ return "POST", f"{self.options.gateway.base_url}/tools/results"
95
+
96
+ elif handler == RequestHandler.GetAllProjects:
97
+ return "GET", f"{self.options.gateway.base_url}/projects"
98
+
99
+ elif handler == RequestHandler.CreateProject:
100
+ return "POST", f"{self.options.gateway.base_url}/projects"
101
+
93
102
  raise TypeError(f"Unknown handler: {handler}")
94
103
 
95
104
  class Conversations(Model):
@@ -2,5 +2,6 @@ from .errors import *
2
2
  from .evaluations import *
3
3
  from .latitude import *
4
4
  from .logs import *
5
+ from .projects import *
5
6
  from .prompts import *
6
7
  from .types import *
@@ -7,6 +7,7 @@ from latitude_sdk.util import Model, StrEnum
7
7
  class ApiErrorCodes(StrEnum):
8
8
  # LatitudeErrorCodes
9
9
  NotFoundError = "NotFoundError"
10
+ BadRequestError = "BadRequestError"
10
11
 
11
12
  # RunErrorCodes
12
13
  AIRunError = "ai_run_error"
@@ -1,5 +1,4 @@
1
- from datetime import datetime
2
- from typing import Any, Optional, Union
1
+ from typing import Optional
3
2
 
4
3
  from latitude_sdk.client import (
5
4
  AnnotateEvaluationRequestBody,
@@ -7,24 +6,16 @@ from latitude_sdk.client import (
7
6
  Client,
8
7
  RequestHandler,
9
8
  )
10
- from latitude_sdk.sdk.types import SdkOptions
11
- from latitude_sdk.util import Field, Model
9
+ from latitude_sdk.sdk.types import EvaluationResult, SdkOptions
10
+ from latitude_sdk.util import Model
12
11
 
13
12
 
14
13
  class AnnotateEvaluationOptions(Model):
15
14
  reason: str
16
15
 
17
16
 
18
- class AnnotateEvaluationResult(Model):
19
- uuid: str
20
- score: int
21
- normalized_score: int = Field(alias=str("normalizedScore"))
22
- metadata: dict[str, Any]
23
- has_passed: bool = Field(alias=str("hasPassed"))
24
- created_at: datetime = Field(alias=str("createdAt"))
25
- updated_at: datetime = Field(alias=str("updatedAt"))
26
- version_uuid: str = Field(alias=str("versionUuid"))
27
- error: Optional[Union[str, None]] = None
17
+ class AnnotateEvaluationResult(EvaluationResult, Model):
18
+ pass
28
19
 
29
20
 
30
21
  class Evaluations:
@@ -6,6 +6,7 @@ from latitude_sdk.client import Client, ClientOptions, RouterOptions
6
6
  from latitude_sdk.env import env
7
7
  from latitude_sdk.sdk.evaluations import Evaluations
8
8
  from latitude_sdk.sdk.logs import Logs
9
+ from latitude_sdk.sdk.projects import Projects
9
10
  from latitude_sdk.sdk.prompts import Prompts
10
11
  from latitude_sdk.sdk.types import GatewayOptions, LogSources, SdkOptions
11
12
  from latitude_sdk.util import Model
@@ -49,6 +50,7 @@ class Latitude:
49
50
 
50
51
  promptl: Promptl
51
52
 
53
+ projects: Projects
52
54
  prompts: Prompts
53
55
  logs: Logs
54
56
  evaluations: Evaluations
@@ -76,6 +78,8 @@ class Latitude:
76
78
  )
77
79
 
78
80
  self.promptl = Promptl(self._options.promptl)
81
+
82
+ self.projects = Projects(self._client, self._options)
79
83
  self.prompts = Prompts(self._client, self.promptl, self._options)
80
84
  self.logs = Logs(self._client, self._options)
81
85
  self.evaluations = Evaluations(self._client, self._options)
@@ -0,0 +1,35 @@
1
+ from typing import List
2
+
3
+ from latitude_sdk.client import Client, CreateProjectRequestBody, RequestHandler
4
+ from latitude_sdk.sdk.types import Project, SdkOptions, Version
5
+ from latitude_sdk.util import Adapter as AdapterUtil
6
+ from latitude_sdk.util import Model
7
+
8
+ _GetAllProjectResults = AdapterUtil[List[Project]](List[Project])
9
+
10
+
11
+ class CreateProjectResult(Model):
12
+ project: Project
13
+ version: Version
14
+
15
+
16
+ class Projects:
17
+ _options: SdkOptions
18
+ _client: Client
19
+
20
+ def __init__(self, client: Client, options: SdkOptions):
21
+ self._options = options
22
+ self._client = client
23
+
24
+ async def get_all(self) -> List[Project]:
25
+ async with self._client.request(
26
+ handler=RequestHandler.GetAllProjects,
27
+ ) as response:
28
+ return _GetAllProjectResults.validate_json(response.content)
29
+
30
+ async def create(self, name: str) -> CreateProjectResult:
31
+ async with self._client.request(
32
+ handler=RequestHandler.CreateProject,
33
+ body=CreateProjectRequestBody(name=name),
34
+ ) as response:
35
+ return CreateProjectResult.model_validate_json(response.content)
@@ -1,7 +1,6 @@
1
- import asyncio
2
- from typing import Any, AsyncGenerator, List, Optional, Sequence, Tuple, Union
1
+ from typing import Any, AsyncGenerator, List, Optional, Sequence
3
2
 
4
- from promptl_ai import Adapter, Message, MessageLike, Promptl, ToolMessage, ToolResultContent
3
+ from promptl_ai import Adapter, Message, MessageLike, Promptl
5
4
  from promptl_ai.bindings.types import _Message
6
5
 
7
6
  from latitude_sdk.client import (
@@ -16,21 +15,22 @@ from latitude_sdk.client import (
16
15
  RequestHandler,
17
16
  RunPromptRequestBody,
18
17
  RunPromptRequestParams,
18
+ ToolResultsRequestBody,
19
19
  )
20
20
  from latitude_sdk.sdk.errors import ApiError, ApiErrorCodes
21
21
  from latitude_sdk.sdk.types import (
22
- AGENT_END_TOOL_NAME,
23
22
  ChainEvents,
24
23
  FinishedResult,
25
24
  OnStep,
26
25
  OnToolCall,
27
26
  OnToolCallDetails,
28
27
  Prompt,
28
+ ProviderEvents,
29
+ ProviderEventToolCalled,
29
30
  Providers,
30
31
  SdkOptions,
31
32
  StreamCallbacks,
32
33
  StreamEvents,
33
- ToolCall,
34
34
  ToolResult,
35
35
  _LatitudeEvent,
36
36
  )
@@ -52,10 +52,6 @@ _PROMPT_ATTR_TO_ADAPTER_ATTR = {
52
52
  }
53
53
 
54
54
 
55
- class OnToolCallPaused(Exception):
56
- pass
57
-
58
-
59
55
  class PromptOptions(Model):
60
56
  project_id: Optional[int] = None
61
57
  version_uuid: Optional[str] = None
@@ -142,31 +138,19 @@ class Prompts:
142
138
  response="Project ID is required",
143
139
  )
144
140
 
145
- async def _extract_agent_tool_requests(
146
- self, tool_requests: List[ToolCall]
147
- ) -> Tuple[List[ToolCall], List[ToolCall]]:
148
- agent: List[ToolCall] = []
149
- other: List[ToolCall] = []
150
-
151
- for tool in tool_requests:
152
- if tool.name == AGENT_END_TOOL_NAME:
153
- agent.append(tool)
154
- else:
155
- other.append(tool)
156
-
157
- return agent, other
158
-
159
141
  async def _handle_stream(
160
- self, stream: AsyncGenerator[ClientEvent, Any], on_event: Optional[StreamCallbacks.OnEvent]
142
+ self,
143
+ stream: AsyncGenerator[ClientEvent, Any],
144
+ on_event: Optional[StreamCallbacks.OnEvent],
145
+ tools: Optional[dict[str, OnToolCall]],
161
146
  ) -> FinishedResult:
162
147
  uuid = None
163
148
  conversation: List[Message] = []
164
149
  response = None
165
- agent_response = None
166
- tool_requests: List[ToolCall] = []
167
150
 
168
151
  async for stream_event in stream:
169
152
  event = None
153
+ tool_call = None
170
154
 
171
155
  if stream_event.event == str(StreamEvents.Latitude):
172
156
  event = _LatitudeEvent.validate_json(stream_event.data)
@@ -176,9 +160,6 @@ class Prompts:
176
160
  if event.type == ChainEvents.ProviderCompleted:
177
161
  response = event.response
178
162
 
179
- elif event.type == ChainEvents.ToolsRequested:
180
- tool_requests = event.tools
181
-
182
163
  elif event.type == ChainEvents.ChainError:
183
164
  raise ApiError(
184
165
  status=400,
@@ -191,6 +172,9 @@ class Prompts:
191
172
  event = stream_event.json()
192
173
  event["event"] = StreamEvents.Provider
193
174
 
175
+ if event.get("type") == str(ProviderEvents.ToolCalled):
176
+ tool_call = ProviderEventToolCalled.model_validate_json(stream_event.data)
177
+
194
178
  else:
195
179
  raise ApiError(
196
180
  status=500,
@@ -202,6 +186,9 @@ class Prompts:
202
186
  if on_event:
203
187
  on_event(event)
204
188
 
189
+ if tool_call:
190
+ await self._handle_tool_call(tool_call, tools)
191
+
205
192
  if not uuid or not response:
206
193
  raise ApiError(
207
194
  status=500,
@@ -210,21 +197,7 @@ class Prompts:
210
197
  response="Stream ended without a chain-complete event. Missing uuid or response.",
211
198
  )
212
199
 
213
- agent_requests, tool_requests = await self._extract_agent_tool_requests(tool_requests)
214
- if len(agent_requests) > 0:
215
- agent_response = agent_requests[0].arguments
216
-
217
- return FinishedResult(
218
- uuid=uuid,
219
- conversation=conversation,
220
- response=response,
221
- agent_response=agent_response,
222
- tool_requests=tool_requests,
223
- )
224
-
225
- @staticmethod
226
- def _pause_tool_execution() -> Any:
227
- raise OnToolCallPaused()
200
+ return FinishedResult(uuid=uuid, conversation=conversation, response=response)
228
201
 
229
202
  @staticmethod
230
203
  async def _wrap_tool_handler(
@@ -235,69 +208,45 @@ class Prompts:
235
208
  try:
236
209
  result = await handler(arguments, details)
237
210
 
238
- return ToolResult(**tool_result, result=result)
211
+ return ToolResult(**tool_result, result=result, is_error=False)
239
212
  except Exception as exception:
240
- if isinstance(exception, OnToolCallPaused):
241
- raise exception
242
-
243
213
  return ToolResult(**tool_result, result=str(exception), is_error=True)
244
214
 
245
- async def _handle_tool_calls(
246
- self, result: FinishedResult, options: Union[RunPromptOptions, ChatPromptOptions]
247
- ) -> Optional[FinishedResult]:
248
- if not options.tools:
215
+ async def _handle_tool_call(
216
+ self, tool_call: ProviderEventToolCalled, tools: Optional[dict[str, OnToolCall]]
217
+ ) -> None:
218
+ # NOTE: Do not handle tool calls if user specified no tools
219
+ if not tools:
220
+ return
221
+
222
+ tool_handler = tools.get(tool_call.name)
223
+ if not tool_handler:
249
224
  raise ApiError(
250
225
  status=400,
251
226
  code=ApiErrorCodes.AIRunError,
252
- message="Tools not supplied",
253
- response="Tools not supplied",
227
+ message=f"Tool {tool_call.name} not supplied",
228
+ response=f"Tool {tool_call.name} not supplied",
254
229
  )
255
230
 
256
- for tool_call in result.tool_requests:
257
- if tool_call.name not in options.tools:
258
- raise ApiError(
259
- status=400,
260
- code=ApiErrorCodes.AIRunError,
261
- message=f"Tool {tool_call.name} not supplied",
262
- response=f"Tool {tool_call.name} not supplied",
263
- )
264
-
265
- tool_results = await asyncio.gather(
266
- *[
267
- self._wrap_tool_handler(
268
- options.tools[tool_call.name],
269
- tool_call.arguments,
270
- OnToolCallDetails(
271
- id=tool_call.id,
272
- name=tool_call.name,
273
- conversation_uuid=result.uuid,
274
- messages=result.conversation,
275
- pause_execution=self._pause_tool_execution,
276
- requested_tool_calls=result.tool_requests,
277
- ),
278
- )
279
- for tool_call in result.tool_requests
280
- ],
281
- return_exceptions=False,
231
+ tool_result = await self._wrap_tool_handler(
232
+ tool_handler,
233
+ tool_call.arguments,
234
+ OnToolCallDetails(
235
+ id=tool_call.id,
236
+ name=tool_call.name,
237
+ arguments=tool_call.arguments,
238
+ ),
282
239
  )
283
240
 
284
- tool_messages = [
285
- ToolMessage(
286
- content=[
287
- ToolResultContent(
288
- id=tool_result.id,
289
- name=tool_result.name,
290
- result=tool_result.result,
291
- is_error=tool_result.is_error,
292
- )
293
- ]
294
- )
295
- for tool_result in tool_results
296
- ]
297
-
298
- next_result = await self.chat(result.uuid, tool_messages, ChatPromptOptions(**dict(options)))
299
-
300
- return FinishedResult(**dict(next_result)) if next_result else None
241
+ async with self._client.request(
242
+ handler=RequestHandler.ToolResults,
243
+ body=ToolResultsRequestBody(
244
+ tool_call_id=tool_call.id,
245
+ result=tool_result.result,
246
+ is_error=tool_result.is_error,
247
+ ),
248
+ ):
249
+ pass
301
250
 
302
251
  async def get(self, path: str, options: Optional[GetPromptOptions] = None) -> GetPromptResult:
303
252
  options = GetPromptOptions(**{**dict(self._options), **dict(options or {})})
@@ -364,22 +313,16 @@ class Prompts:
364
313
  path=path,
365
314
  parameters=options.parameters,
366
315
  custom_identifier=options.custom_identifier,
316
+ tools=list(options.tools.keys()) if options.tools and options.stream else None,
367
317
  stream=options.stream,
368
318
  ),
319
+ stream=options.stream,
369
320
  ) as response:
370
321
  if options.stream:
371
- result = await self._handle_stream(response.sse(), options.on_event)
322
+ result = await self._handle_stream(response.sse(), options.on_event, options.tools)
372
323
  else:
373
324
  result = RunPromptResult.model_validate_json(response.content)
374
325
 
375
- if options.tools and result.tool_requests:
376
- try:
377
- # NOTE: The last sdk.chat called will already call on_finished
378
- final_result = await self._handle_tool_calls(result, options)
379
- return RunPromptResult(**dict(final_result)) if final_result else None
380
- except OnToolCallPaused:
381
- pass
382
-
383
326
  if options.on_finished:
384
327
  options.on_finished(FinishedResult(**dict(result)))
385
328
 
@@ -416,22 +359,16 @@ class Prompts:
416
359
  ),
417
360
  body=ChatPromptRequestBody(
418
361
  messages=messages,
362
+ tools=list(options.tools.keys()) if options.tools and options.stream else None,
419
363
  stream=options.stream,
420
364
  ),
365
+ stream=options.stream,
421
366
  ) as response:
422
367
  if options.stream:
423
- result = await self._handle_stream(response.sse(), options.on_event)
368
+ result = await self._handle_stream(response.sse(), options.on_event, options.tools)
424
369
  else:
425
370
  result = ChatPromptResult.model_validate_json(response.content)
426
371
 
427
- if options.tools and result.tool_requests:
428
- try:
429
- # NOTE: The last sdk.chat called will already call on_finished
430
- final_result = await self._handle_tool_calls(result, options)
431
- return ChatPromptResult(**dict(final_result)) if final_result else None
432
- except OnToolCallPaused:
433
- pass
434
-
435
372
  if options.on_finished:
436
373
  options.on_finished(FinishedResult(**dict(result)))
437
374