latitude-sdk 4.0.0b1__tar.gz → 5.0.0b1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/PKG-INFO +8 -8
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/README.md +7 -7
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/pyproject.toml +3 -1
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/scripts/test.py +1 -1
- latitude_sdk-5.0.0b1/src/latitude_sdk/__init__.py +7 -0
- latitude_sdk-5.0.0b1/src/latitude_sdk/client/client.py +147 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/payloads.py +3 -3
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/__init__.py +1 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/errors.py +1 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/evaluations.py +5 -14
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/latitude.py +1 -0
- latitude_sdk-5.0.0b1/src/latitude_sdk/sdk/projects.py +35 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/prompts.py +47 -83
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/types.py +46 -13
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/util/utils.py +5 -0
- latitude_sdk-5.0.0b1/src/latitude_sdk/version/__init__.py +1 -0
- latitude_sdk-5.0.0b1/src/latitude_sdk/version/version.py +63 -0
- latitude_sdk-5.0.0b1/tests/projects/create_test.py +53 -0
- latitude_sdk-5.0.0b1/tests/projects/get_all_test.py +57 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/chat_test.py +128 -9
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/run_test.py +187 -25
- latitude_sdk-5.0.0b1/tests/sdk/__init__.py +0 -0
- latitude_sdk-4.0.0b1/tests/test_acceptance.py → latitude_sdk-5.0.0b1/tests/sdk/acceptance_test.py +50 -39
- latitude_sdk-5.0.0b1/tests/sdk/client_test.py +110 -0
- latitude_sdk-5.0.0b1/tests/sdk/memory_test.py +376 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/utils/fixtures.py +43 -287
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/utils/utils.py +9 -2
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/uv.lock +18 -1
- latitude_sdk-4.0.0b1/src/latitude_sdk/__init__.py +0 -1
- latitude_sdk-4.0.0b1/src/latitude_sdk/client/client.py +0 -276
- latitude_sdk-4.0.0b1/src/latitude_sdk/sdk/projects.py +0 -41
- latitude_sdk-4.0.0b1/tests/projects/create_test.py +0 -101
- latitude_sdk-4.0.0b1/tests/projects/get_all_test.py +0 -137
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/.gitignore +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/.python-version +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/LICENSE.md +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/scripts/format.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/scripts/lint.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/client/router.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/env/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/env/env.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/py.typed +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/sdk/logs.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/src/latitude_sdk/util/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/evaluations/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/evaluations/annotate_test.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/logs/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/logs/create_test.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/projects/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/__init__.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/get_all_test.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/get_or_create_test.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/get_test.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/render_chain_test.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/prompts/render_test.py +0 -0
- {latitude_sdk-4.0.0b1 → latitude_sdk-5.0.0b1}/tests/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: latitude-sdk
|
3
|
-
Version:
|
3
|
+
Version: 5.0.0b1
|
4
4
|
Summary: Latitude SDK for Python
|
5
5
|
Project-URL: repository, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python
|
6
6
|
Project-URL: homepage, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python#readme
|
@@ -62,7 +62,13 @@ Requires uv `0.5.10` or higher.
|
|
62
62
|
|
63
63
|
### Running only a specific test
|
64
64
|
|
65
|
-
|
65
|
+
Specify the test inline:
|
66
|
+
|
67
|
+
```python
|
68
|
+
uv run scripts/test.py <test_path>::<test_case>::<test_name>
|
69
|
+
```
|
70
|
+
|
71
|
+
Or mark the test with an `only` marker:
|
66
72
|
|
67
73
|
```python
|
68
74
|
import pytest
|
@@ -78,12 +84,6 @@ async def my_test(self):
|
|
78
84
|
uv run scripts/test.py -m only
|
79
85
|
```
|
80
86
|
|
81
|
-
Another way is to specify the test in line:
|
82
|
-
|
83
|
-
```python
|
84
|
-
uv run scripts/test.py <test_path>::<test_case>::<test_name>
|
85
|
-
```
|
86
|
-
|
87
87
|
## License
|
88
88
|
|
89
89
|
The SDK is licensed under the [MIT License](https://opensource.org/licenses/MIT) - read the [LICENSE](/LICENSE) file for details.
|
@@ -43,7 +43,13 @@ Requires uv `0.5.10` or higher.
|
|
43
43
|
|
44
44
|
### Running only a specific test
|
45
45
|
|
46
|
-
|
46
|
+
Specify the test inline:
|
47
|
+
|
48
|
+
```python
|
49
|
+
uv run scripts/test.py <test_path>::<test_case>::<test_name>
|
50
|
+
```
|
51
|
+
|
52
|
+
Or mark the test with an `only` marker:
|
47
53
|
|
48
54
|
```python
|
49
55
|
import pytest
|
@@ -59,12 +65,6 @@ async def my_test(self):
|
|
59
65
|
uv run scripts/test.py -m only
|
60
66
|
```
|
61
67
|
|
62
|
-
Another way is to specify the test in line:
|
63
|
-
|
64
|
-
```python
|
65
|
-
uv run scripts/test.py <test_path>::<test_case>::<test_name>
|
66
|
-
```
|
67
|
-
|
68
68
|
## License
|
69
69
|
|
70
70
|
The SDK is licensed under the [MIT License](https://opensource.org/licenses/MIT) - read the [LICENSE](/LICENSE) file for details.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "latitude-sdk"
|
3
|
-
version = "
|
3
|
+
version = "5.0.0-beta.1"
|
4
4
|
description = "Latitude SDK for Python"
|
5
5
|
authors = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
|
6
6
|
maintainers = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
|
@@ -27,6 +27,7 @@ dev = [
|
|
27
27
|
"pyright>=1.1.401",
|
28
28
|
"ruff>=0.8.3",
|
29
29
|
"sh>=1.14.3",
|
30
|
+
"psutil>=7.0.0",
|
30
31
|
]
|
31
32
|
|
32
33
|
[tool.pyright]
|
@@ -35,6 +36,7 @@ typeCheckingMode = "strict"
|
|
35
36
|
reportMissingTypeStubs = false
|
36
37
|
reportUnnecessaryIsInstance = false
|
37
38
|
reportPrivateUsage = false
|
39
|
+
reportPrivateImportUsage = false
|
38
40
|
|
39
41
|
[tool.ruff]
|
40
42
|
target-version = "py39"
|
@@ -0,0 +1,147 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
3
|
+
from contextlib import asynccontextmanager
|
4
|
+
from typing import Any, AsyncGenerator, Optional
|
5
|
+
|
6
|
+
import httpx
|
7
|
+
import httpx_sse
|
8
|
+
|
9
|
+
from latitude_sdk.client.payloads import ErrorResponse, RequestBody, RequestHandler, RequestParams
|
10
|
+
from latitude_sdk.client.router import Router, RouterOptions
|
11
|
+
from latitude_sdk.sdk.errors import ApiError, ApiErrorCodes, ApiErrorDbRef
|
12
|
+
from latitude_sdk.sdk.types import LogSources
|
13
|
+
from latitude_sdk.util import Model
|
14
|
+
from latitude_sdk.version import version
|
15
|
+
|
16
|
+
RETRIABLE_STATUSES = [408, 429, 500, 502, 503, 504]
|
17
|
+
|
18
|
+
ClientEvent = httpx_sse.ServerSentEvent
|
19
|
+
|
20
|
+
|
21
|
+
class ClientResponse(httpx.Response):
|
22
|
+
async def sse(self: httpx.Response) -> AsyncGenerator[ClientEvent, Any]:
|
23
|
+
source = httpx_sse.EventSource(self)
|
24
|
+
|
25
|
+
async for event in source.aiter_sse():
|
26
|
+
yield event
|
27
|
+
|
28
|
+
|
29
|
+
httpx.Response.sse = ClientResponse.sse # pyright: ignore [reportAttributeAccessIssue]
|
30
|
+
|
31
|
+
|
32
|
+
class ClientOptions(Model):
|
33
|
+
api_key: str
|
34
|
+
retries: int
|
35
|
+
delay: float
|
36
|
+
timeout: Optional[float]
|
37
|
+
source: LogSources
|
38
|
+
router: RouterOptions
|
39
|
+
|
40
|
+
|
41
|
+
class Client:
|
42
|
+
options: ClientOptions
|
43
|
+
router: Router
|
44
|
+
|
45
|
+
def __init__(self, options: ClientOptions):
|
46
|
+
self.options = options
|
47
|
+
self.router = Router(options.router)
|
48
|
+
|
49
|
+
@asynccontextmanager
|
50
|
+
async def request(
|
51
|
+
self,
|
52
|
+
handler: RequestHandler,
|
53
|
+
params: Optional[RequestParams] = None,
|
54
|
+
body: Optional[RequestBody] = None,
|
55
|
+
stream: Optional[bool] = None,
|
56
|
+
) -> AsyncGenerator[ClientResponse, Any]:
|
57
|
+
client = httpx.AsyncClient(
|
58
|
+
headers={
|
59
|
+
"Authorization": f"Bearer {self.options.api_key}",
|
60
|
+
"X-Latitude-SDK-Version": version.semver,
|
61
|
+
"Content-Type": "application/json",
|
62
|
+
"Accept": "text/event-stream" if stream else "application/json",
|
63
|
+
},
|
64
|
+
timeout=self.options.timeout,
|
65
|
+
follow_redirects=False,
|
66
|
+
max_redirects=0,
|
67
|
+
)
|
68
|
+
response = None
|
69
|
+
attempt = 1
|
70
|
+
|
71
|
+
try:
|
72
|
+
method, url = self.router.resolve(handler, params)
|
73
|
+
content = None
|
74
|
+
if body:
|
75
|
+
content = json.dumps(
|
76
|
+
{
|
77
|
+
**json.loads(body.model_dump_json()),
|
78
|
+
"__internal": {"source": self.options.source},
|
79
|
+
}
|
80
|
+
)
|
81
|
+
|
82
|
+
while attempt <= self.options.retries:
|
83
|
+
try:
|
84
|
+
request = client.build_request(method=method, url=url, content=content)
|
85
|
+
response = await client.send(request=request, stream=stream or False)
|
86
|
+
response.raise_for_status()
|
87
|
+
|
88
|
+
yield response # pyright: ignore [reportReturnType]
|
89
|
+
break
|
90
|
+
|
91
|
+
except Exception as exception:
|
92
|
+
if isinstance(exception, ApiError):
|
93
|
+
raise exception
|
94
|
+
|
95
|
+
if attempt >= self.options.retries:
|
96
|
+
raise await self._exception(exception, response) from exception
|
97
|
+
|
98
|
+
if response and response.status_code in RETRIABLE_STATUSES:
|
99
|
+
await asyncio.sleep(self.options.delay * (2 ** (attempt - 1)))
|
100
|
+
else:
|
101
|
+
raise await self._exception(exception, response) from exception
|
102
|
+
|
103
|
+
finally:
|
104
|
+
if response:
|
105
|
+
await response.aclose()
|
106
|
+
|
107
|
+
attempt += 1
|
108
|
+
|
109
|
+
except Exception as exception:
|
110
|
+
if isinstance(exception, ApiError):
|
111
|
+
raise exception
|
112
|
+
|
113
|
+
raise await self._exception(exception, response) from exception
|
114
|
+
|
115
|
+
finally:
|
116
|
+
await client.aclose()
|
117
|
+
|
118
|
+
async def _exception(self, exception: Exception, response: Optional[httpx.Response] = None) -> ApiError:
|
119
|
+
if not response:
|
120
|
+
return ApiError(
|
121
|
+
status=500,
|
122
|
+
code=ApiErrorCodes.InternalServerError,
|
123
|
+
message=str(exception),
|
124
|
+
response=str(exception),
|
125
|
+
)
|
126
|
+
|
127
|
+
try:
|
128
|
+
if not response.is_stream_consumed:
|
129
|
+
await response.aread()
|
130
|
+
|
131
|
+
error = ErrorResponse.model_validate_json(response.content)
|
132
|
+
|
133
|
+
return ApiError(
|
134
|
+
status=response.status_code,
|
135
|
+
code=error.code,
|
136
|
+
message=error.message,
|
137
|
+
response=response.text,
|
138
|
+
db_ref=ApiErrorDbRef(**dict(error.db_ref)) if error.db_ref else None,
|
139
|
+
)
|
140
|
+
|
141
|
+
except Exception:
|
142
|
+
return ApiError(
|
143
|
+
status=response.status_code,
|
144
|
+
code=ApiErrorCodes.InternalServerError,
|
145
|
+
message=str(exception),
|
146
|
+
response=response.text,
|
147
|
+
)
|
@@ -44,8 +44,8 @@ class RunPromptRequestBody(Model):
|
|
44
44
|
path: str
|
45
45
|
parameters: Optional[Dict[str, Any]] = None
|
46
46
|
custom_identifier: Optional[str] = Field(default=None, alias=str("customIdentifier"))
|
47
|
-
stream: Optional[bool] = None
|
48
47
|
tools: Optional[List[str]] = None
|
48
|
+
stream: Optional[bool] = None
|
49
49
|
|
50
50
|
|
51
51
|
class ChatPromptRequestParams(Model):
|
@@ -54,6 +54,7 @@ class ChatPromptRequestParams(Model):
|
|
54
54
|
|
55
55
|
class ChatPromptRequestBody(Model):
|
56
56
|
messages: List[Message]
|
57
|
+
tools: Optional[List[str]] = None
|
57
58
|
stream: Optional[bool] = None
|
58
59
|
|
59
60
|
|
@@ -81,11 +82,10 @@ class AnnotateEvaluationRequestParams(EvaluationRequestParams, Model):
|
|
81
82
|
|
82
83
|
|
83
84
|
class AnnotateEvaluationRequestBody(Model):
|
84
|
-
score: int
|
85
|
-
|
86
85
|
class Metadata(Model):
|
87
86
|
reason: str
|
88
87
|
|
88
|
+
score: int
|
89
89
|
metadata: Optional[Metadata] = None
|
90
90
|
|
91
91
|
|
@@ -1,5 +1,4 @@
|
|
1
|
-
from
|
2
|
-
from typing import Any, Optional, Union
|
1
|
+
from typing import Optional
|
3
2
|
|
4
3
|
from latitude_sdk.client import (
|
5
4
|
AnnotateEvaluationRequestBody,
|
@@ -7,24 +6,16 @@ from latitude_sdk.client import (
|
|
7
6
|
Client,
|
8
7
|
RequestHandler,
|
9
8
|
)
|
10
|
-
from latitude_sdk.sdk.types import SdkOptions
|
11
|
-
from latitude_sdk.util import
|
9
|
+
from latitude_sdk.sdk.types import EvaluationResult, SdkOptions
|
10
|
+
from latitude_sdk.util import Model
|
12
11
|
|
13
12
|
|
14
13
|
class AnnotateEvaluationOptions(Model):
|
15
14
|
reason: str
|
16
15
|
|
17
16
|
|
18
|
-
class AnnotateEvaluationResult(Model):
|
19
|
-
|
20
|
-
score: int
|
21
|
-
normalized_score: int = Field(alias=str("normalizedScore"))
|
22
|
-
metadata: dict[str, Any]
|
23
|
-
has_passed: bool = Field(alias=str("hasPassed"))
|
24
|
-
created_at: datetime = Field(alias=str("createdAt"))
|
25
|
-
updated_at: datetime = Field(alias=str("updatedAt"))
|
26
|
-
version_uuid: str = Field(alias=str("versionUuid"))
|
27
|
-
error: Optional[Union[str, None]] = None
|
17
|
+
class AnnotateEvaluationResult(EvaluationResult, Model):
|
18
|
+
pass
|
28
19
|
|
29
20
|
|
30
21
|
class Evaluations:
|
@@ -0,0 +1,35 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
3
|
+
from latitude_sdk.client import Client, CreateProjectRequestBody, RequestHandler
|
4
|
+
from latitude_sdk.sdk.types import Project, SdkOptions, Version
|
5
|
+
from latitude_sdk.util import Adapter as AdapterUtil
|
6
|
+
from latitude_sdk.util import Model
|
7
|
+
|
8
|
+
_GetAllProjectResults = AdapterUtil[List[Project]](List[Project])
|
9
|
+
|
10
|
+
|
11
|
+
class CreateProjectResult(Model):
|
12
|
+
project: Project
|
13
|
+
version: Version
|
14
|
+
|
15
|
+
|
16
|
+
class Projects:
|
17
|
+
_options: SdkOptions
|
18
|
+
_client: Client
|
19
|
+
|
20
|
+
def __init__(self, client: Client, options: SdkOptions):
|
21
|
+
self._options = options
|
22
|
+
self._client = client
|
23
|
+
|
24
|
+
async def get_all(self) -> List[Project]:
|
25
|
+
async with self._client.request(
|
26
|
+
handler=RequestHandler.GetAllProjects,
|
27
|
+
) as response:
|
28
|
+
return _GetAllProjectResults.validate_json(response.content)
|
29
|
+
|
30
|
+
async def create(self, name: str) -> CreateProjectResult:
|
31
|
+
async with self._client.request(
|
32
|
+
handler=RequestHandler.CreateProject,
|
33
|
+
body=CreateProjectRequestBody(name=name),
|
34
|
+
) as response:
|
35
|
+
return CreateProjectResult.model_validate_json(response.content)
|
@@ -1,11 +1,6 @@
|
|
1
|
-
from typing import Any, AsyncGenerator,
|
1
|
+
from typing import Any, AsyncGenerator, List, Optional, Sequence
|
2
2
|
|
3
|
-
from promptl_ai import
|
4
|
-
Adapter,
|
5
|
-
Message,
|
6
|
-
MessageLike,
|
7
|
-
Promptl,
|
8
|
-
)
|
3
|
+
from promptl_ai import Adapter, Message, MessageLike, Promptl
|
9
4
|
from promptl_ai.bindings.types import _Message
|
10
5
|
|
11
6
|
from latitude_sdk.client import (
|
@@ -30,6 +25,8 @@ from latitude_sdk.sdk.types import (
|
|
30
25
|
OnToolCall,
|
31
26
|
OnToolCallDetails,
|
32
27
|
Prompt,
|
28
|
+
ProviderEvents,
|
29
|
+
ProviderEventToolCalled,
|
33
30
|
Providers,
|
34
31
|
SdkOptions,
|
35
32
|
StreamCallbacks,
|
@@ -145,7 +142,7 @@ class Prompts:
|
|
145
142
|
self,
|
146
143
|
stream: AsyncGenerator[ClientEvent, Any],
|
147
144
|
on_event: Optional[StreamCallbacks.OnEvent],
|
148
|
-
|
145
|
+
tools: Optional[dict[str, OnToolCall]],
|
149
146
|
) -> FinishedResult:
|
150
147
|
uuid = None
|
151
148
|
conversation: List[Message] = []
|
@@ -153,6 +150,7 @@ class Prompts:
|
|
153
150
|
|
154
151
|
async for stream_event in stream:
|
155
152
|
event = None
|
153
|
+
tool_call = None
|
156
154
|
|
157
155
|
if stream_event.event == str(StreamEvents.Latitude):
|
158
156
|
event = _LatitudeEvent.validate_json(stream_event.data)
|
@@ -174,9 +172,8 @@ class Prompts:
|
|
174
172
|
event = stream_event.json()
|
175
173
|
event["event"] = StreamEvents.Provider
|
176
174
|
|
177
|
-
|
178
|
-
|
179
|
-
await on_tool_call(event)
|
175
|
+
if event.get("type") == str(ProviderEvents.ToolCalled):
|
176
|
+
tool_call = ProviderEventToolCalled.model_validate_json(stream_event.data)
|
180
177
|
|
181
178
|
else:
|
182
179
|
raise ApiError(
|
@@ -189,6 +186,9 @@ class Prompts:
|
|
189
186
|
if on_event:
|
190
187
|
on_event(event)
|
191
188
|
|
189
|
+
if tool_call:
|
190
|
+
await self._handle_tool_call(tool_call, tools)
|
191
|
+
|
192
192
|
if not uuid or not response:
|
193
193
|
raise ApiError(
|
194
194
|
status=500,
|
@@ -197,11 +197,7 @@ class Prompts:
|
|
197
197
|
response="Stream ended without a chain-complete event. Missing uuid or response.",
|
198
198
|
)
|
199
199
|
|
200
|
-
return FinishedResult(
|
201
|
-
uuid=uuid,
|
202
|
-
conversation=conversation,
|
203
|
-
response=response,
|
204
|
-
)
|
200
|
+
return FinishedResult(uuid=uuid, conversation=conversation, response=response)
|
205
201
|
|
206
202
|
@staticmethod
|
207
203
|
async def _wrap_tool_handler(
|
@@ -212,66 +208,45 @@ class Prompts:
|
|
212
208
|
try:
|
213
209
|
result = await handler(arguments, details)
|
214
210
|
|
215
|
-
return ToolResult(**tool_result, result=result)
|
211
|
+
return ToolResult(**tool_result, result=result, is_error=False)
|
216
212
|
except Exception as exception:
|
217
213
|
return ToolResult(**tool_result, result=str(exception), is_error=True)
|
218
214
|
|
219
215
|
async def _handle_tool_call(
|
220
|
-
self,
|
221
|
-
event: dict[str, Any],
|
222
|
-
tools: dict[str, OnToolCall],
|
216
|
+
self, tool_call: ProviderEventToolCalled, tools: Optional[dict[str, OnToolCall]]
|
223
217
|
) -> None:
|
224
|
-
|
225
|
-
|
226
|
-
args: dict[str, Any] = event["args"]
|
227
|
-
|
228
|
-
tool = tools.get(toolName)
|
229
|
-
if not tool:
|
218
|
+
# NOTE: Do not handle tool calls if user specified no tools
|
219
|
+
if not tools:
|
230
220
|
return
|
231
221
|
|
222
|
+
tool_handler = tools.get(tool_call.name)
|
223
|
+
if not tool_handler:
|
224
|
+
raise ApiError(
|
225
|
+
status=400,
|
226
|
+
code=ApiErrorCodes.AIRunError,
|
227
|
+
message=f"Tool {tool_call.name} not supplied",
|
228
|
+
response=f"Tool {tool_call.name} not supplied",
|
229
|
+
)
|
230
|
+
|
232
231
|
tool_result = await self._wrap_tool_handler(
|
233
|
-
|
234
|
-
|
232
|
+
tool_handler,
|
233
|
+
tool_call.arguments,
|
235
234
|
OnToolCallDetails(
|
236
|
-
id=
|
237
|
-
name=
|
238
|
-
arguments=
|
235
|
+
id=tool_call.id,
|
236
|
+
name=tool_call.name,
|
237
|
+
arguments=tool_call.arguments,
|
239
238
|
),
|
240
239
|
)
|
241
240
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
) as _:
|
252
|
-
pass
|
253
|
-
except Exception as exception:
|
254
|
-
if not isinstance(exception, ApiError):
|
255
|
-
exception = ApiError(
|
256
|
-
status=500,
|
257
|
-
code=ApiErrorCodes.InternalServerError,
|
258
|
-
message=str(exception),
|
259
|
-
response=str(exception),
|
260
|
-
)
|
261
|
-
|
262
|
-
# Add context about which tool failed
|
263
|
-
message = f"Failed to execute tool {toolName}. \nLatitude API returned the following error:\
|
264
|
-
\n\n{exception.message}"
|
265
|
-
|
266
|
-
raise ApiError(
|
267
|
-
status=exception.status,
|
268
|
-
code=exception.code,
|
269
|
-
message=message,
|
270
|
-
response=exception.response,
|
271
|
-
) from exception
|
272
|
-
|
273
|
-
def _on_tool_call(self, tools: dict[str, OnToolCall]) -> Callable[[dict[str, Any]], Any]:
|
274
|
-
return lambda event: self._handle_tool_call(event, tools)
|
241
|
+
async with self._client.request(
|
242
|
+
handler=RequestHandler.ToolResults,
|
243
|
+
body=ToolResultsRequestBody(
|
244
|
+
tool_call_id=tool_call.id,
|
245
|
+
result=tool_result.result,
|
246
|
+
is_error=tool_result.is_error,
|
247
|
+
),
|
248
|
+
):
|
249
|
+
pass
|
275
250
|
|
276
251
|
async def get(self, path: str, options: Optional[GetPromptOptions] = None) -> GetPromptResult:
|
277
252
|
options = GetPromptOptions(**{**dict(self._options), **dict(options or {})})
|
@@ -338,16 +313,13 @@ class Prompts:
|
|
338
313
|
path=path,
|
339
314
|
parameters=options.parameters,
|
340
315
|
custom_identifier=options.custom_identifier,
|
341
|
-
stream=options.stream,
|
342
316
|
tools=list(options.tools.keys()) if options.tools and options.stream else None,
|
317
|
+
stream=options.stream,
|
343
318
|
),
|
319
|
+
stream=options.stream,
|
344
320
|
) as response:
|
345
321
|
if options.stream:
|
346
|
-
result = await self._handle_stream(
|
347
|
-
response.sse(),
|
348
|
-
options.on_event,
|
349
|
-
self._on_tool_call(options.tools if options.tools else {}),
|
350
|
-
)
|
322
|
+
result = await self._handle_stream(response.sse(), options.on_event, options.tools)
|
351
323
|
else:
|
352
324
|
result = RunPromptResult.model_validate_json(response.content)
|
353
325
|
|
@@ -373,10 +345,7 @@ class Prompts:
|
|
373
345
|
return None
|
374
346
|
|
375
347
|
async def chat(
|
376
|
-
self,
|
377
|
-
uuid: str,
|
378
|
-
messages: Sequence[MessageLike],
|
379
|
-
options: Optional[ChatPromptOptions] = None,
|
348
|
+
self, uuid: str, messages: Sequence[MessageLike], options: Optional[ChatPromptOptions] = None
|
380
349
|
) -> Optional[ChatPromptResult]:
|
381
350
|
options = ChatPromptOptions(**{**dict(self._options), **dict(options or {})})
|
382
351
|
|
@@ -390,15 +359,13 @@ class Prompts:
|
|
390
359
|
),
|
391
360
|
body=ChatPromptRequestBody(
|
392
361
|
messages=messages,
|
362
|
+
tools=list(options.tools.keys()) if options.tools and options.stream else None,
|
393
363
|
stream=options.stream,
|
394
364
|
),
|
365
|
+
stream=options.stream,
|
395
366
|
) as response:
|
396
367
|
if options.stream:
|
397
|
-
result = await self._handle_stream(
|
398
|
-
response.sse(),
|
399
|
-
options.on_event,
|
400
|
-
self._on_tool_call(options.tools if options.tools else {}),
|
401
|
-
)
|
368
|
+
result = await self._handle_stream(response.sse(), options.on_event, options.tools)
|
402
369
|
else:
|
403
370
|
result = ChatPromptResult.model_validate_json(response.content)
|
404
371
|
|
@@ -451,10 +418,7 @@ class Prompts:
|
|
451
418
|
)
|
452
419
|
|
453
420
|
async def render_chain(
|
454
|
-
self,
|
455
|
-
prompt: Prompt,
|
456
|
-
on_step: OnStep,
|
457
|
-
options: Optional[RenderChainOptions] = None,
|
421
|
+
self, prompt: Prompt, on_step: OnStep, options: Optional[RenderChainOptions] = None
|
458
422
|
) -> RenderChainResult:
|
459
423
|
options = RenderChainOptions(**{**dict(self._options), **dict(options or {})})
|
460
424
|
adapter = options.adapter or _PROVIDER_TO_ADAPTER.get(prompt.provider or Providers.OpenAI, Adapter.OpenAI)
|