latitude-sdk 1.0.2__tar.gz → 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/PKG-INFO +23 -1
  2. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/README.md +22 -0
  3. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/pyproject.toml +1 -1
  4. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/client/payloads.py +6 -0
  5. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/client/router.py +11 -0
  6. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/sdk/latitude.py +1 -1
  7. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/sdk/prompts.py +63 -56
  8. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/sdk/types.py +97 -45
  9. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/prompts/chat_test.py +28 -28
  10. latitude_sdk-1.1.0/tests/prompts/get_all_test.py +62 -0
  11. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/prompts/run_test.py +50 -50
  12. latitude_sdk-1.1.0/tests/utils/fixtures.py +1506 -0
  13. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/utils/utils.py +2 -2
  14. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/uv.lock +232 -229
  15. latitude_sdk-1.0.2/tests/utils/fixtures.py +0 -998
  16. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/.gitignore +0 -0
  17. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/.python-version +0 -0
  18. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/scripts/format.py +0 -0
  19. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/scripts/lint.py +0 -0
  20. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/scripts/test.py +0 -0
  21. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/__init__.py +0 -0
  22. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/client/__init__.py +0 -0
  23. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/client/client.py +0 -0
  24. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/env/__init__.py +0 -0
  25. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/env/env.py +0 -0
  26. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/py.typed +0 -0
  27. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/sdk/__init__.py +0 -0
  28. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/sdk/errors.py +0 -0
  29. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/sdk/evaluations.py +0 -0
  30. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/sdk/logs.py +0 -0
  31. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/util/__init__.py +0 -0
  32. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/src/latitude_sdk/util/utils.py +0 -0
  33. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/__init__.py +0 -0
  34. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/evaluations/__init__.py +0 -0
  35. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/evaluations/create_result_test.py +0 -0
  36. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/evaluations/trigger_test.py +0 -0
  37. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/logs/__init__.py +0 -0
  38. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/logs/create_test.py +0 -0
  39. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/prompts/__init__.py +0 -0
  40. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/prompts/get_or_create_test.py +0 -0
  41. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/prompts/get_test.py +0 -0
  42. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/prompts/render_chain_test.py +0 -0
  43. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/prompts/render_test.py +0 -0
  44. {latitude_sdk-1.0.2 → latitude_sdk-1.1.0}/tests/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: latitude-sdk
3
- Version: 1.0.2
3
+ Version: 1.1.0
4
4
  Summary: Latitude SDK for Python
5
5
  Project-URL: repository, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python
6
6
  Project-URL: homepage, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python#readme
@@ -60,6 +60,28 @@ Requires uv `0.5.10` or higher.
60
60
  - Build package: `uv build`
61
61
  - Publish package: `uv publish`
62
62
 
63
+ ## Run only one test
64
+
65
+ ```python
66
+ import pytest
67
+
68
+ @pytest.mark.only
69
+ async def my_test(self):
70
+ # ... your code
71
+ ```
72
+
73
+ And then run the tests with the marker `only`:
74
+
75
+ ```sh
76
+ uv run scripts/test.py -m only
77
+ ```
78
+
79
+ Alternatively, run a single test directly by passing its path on the command line:
80
+
81
+ ```sh
82
+ uv run scripts/test.py <test_path>::<test_case>::<test_name>
83
+ ```
84
+
63
85
  ## License
64
86
 
65
87
  The SDK is licensed under the [LGPL-3.0 License](https://opensource.org/licenses/LGPL-3.0) - read the [LICENSE](/LICENSE) file for details.
@@ -41,6 +41,28 @@ Requires uv `0.5.10` or higher.
41
41
  - Build package: `uv build`
42
42
  - Publish package: `uv publish`
43
43
 
44
+ ## Run only one test
45
+
46
+ ```python
47
+ import pytest
48
+
49
+ @pytest.mark.only
50
+ async def my_test(self):
51
+ # ... your code
52
+ ```
53
+
54
+ And then run the tests with the marker `only`:
55
+
56
+ ```sh
57
+ uv run scripts/test.py -m only
58
+ ```
59
+
60
+ Alternatively, run a single test directly by passing its path on the command line:
61
+
62
+ ```sh
63
+ uv run scripts/test.py <test_path>::<test_case>::<test_name>
64
+ ```
65
+
44
66
  ## License
45
67
 
46
68
  The SDK is licensed under the [LGPL-3.0 License](https://opensource.org/licenses/LGPL-3.0) - read the [LICENSE](/LICENSE) file for details.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "latitude-sdk"
3
- version = "1.0.2"
3
+ version = "1.1.0"
4
4
  description = "Latitude SDK for Python"
5
5
  authors = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
6
6
  maintainers = [{ name = "Latitude Data SL", email = "hello@latitude.so" }]
@@ -23,6 +23,10 @@ class GetPromptRequestParams(PromptRequestParams, Model):
23
23
  path: str
24
24
 
25
25
 
26
+ class GetAllPromptRequestParams(PromptRequestParams, Model):
27
+ pass
28
+
29
+
26
30
  class GetOrCreatePromptRequestParams(PromptRequestParams, Model):
27
31
  pass
28
32
 
@@ -90,6 +94,7 @@ class CreateEvaluationResultRequestBody(Model):
90
94
 
91
95
  RequestParams = Union[
92
96
  GetPromptRequestParams,
97
+ GetAllPromptRequestParams,
93
98
  GetOrCreatePromptRequestParams,
94
99
  RunPromptRequestParams,
95
100
  ChatPromptRequestParams,
@@ -111,6 +116,7 @@ RequestBody = Union[
111
116
 
112
117
  class RequestHandler(StrEnum):
113
118
  GetPrompt = "GET_PROMPT"
119
+ GetAllPrompts = "GET_ALL_PROMPTS"
114
120
  GetOrCreatePrompt = "GET_OR_CREATE_PROMPT"
115
121
  RunPrompt = "RUN_PROMPT"
116
122
  ChatPrompt = "CHAT_PROMPT"
@@ -4,6 +4,7 @@ from latitude_sdk.client.payloads import (
4
4
  ChatPromptRequestParams,
5
5
  CreateEvaluationResultRequestParams,
6
6
  CreateLogRequestParams,
7
+ GetAllPromptRequestParams,
7
8
  GetOrCreatePromptRequestParams,
8
9
  GetPromptRequestParams,
9
10
  RequestHandler,
@@ -36,6 +37,14 @@ class Router:
36
37
  version_uuid=params.version_uuid,
37
38
  ).prompt(params.path)
38
39
 
40
+ if handler == RequestHandler.GetAllPrompts:
41
+ assert isinstance(params, GetAllPromptRequestParams)
42
+
43
+ return "GET", self.prompts(
44
+ project_id=params.project_id,
45
+ version_uuid=params.version_uuid,
46
+ ).all_prompts
47
+
39
48
  elif handler == RequestHandler.GetOrCreatePrompt:
40
49
  assert isinstance(params, GetOrCreatePromptRequestParams)
41
50
 
@@ -94,6 +103,7 @@ class Router:
94
103
 
95
104
  class Prompts(Model):
96
105
  prompt: Callable[[str], str]
106
+ all_prompts: str
97
107
  get_or_create: str
98
108
  run: str
99
109
  logs: str
@@ -102,6 +112,7 @@ class Router:
102
112
  base_url = f"{self.commits_url(project_id, version_uuid)}/documents"
103
113
 
104
114
  return self.Prompts(
115
+ all_prompts=f"{base_url}",
105
116
  prompt=lambda path: f"{base_url}/{path}",
106
117
  get_or_create=f"{base_url}/get-or-create",
107
118
  run=f"{base_url}/run",
@@ -31,7 +31,7 @@ DEFAULT_INTERNAL_OPTIONS = InternalOptions(
31
31
  host=env.GATEWAY_HOSTNAME,
32
32
  port=env.GATEWAY_PORT,
33
33
  ssl=env.GATEWAY_SSL,
34
- api_version="v2",
34
+ api_version="v3",
35
35
  ),
36
36
  source=LogSources.Api,
37
37
  retries=3,
@@ -1,5 +1,5 @@
1
1
  import asyncio
2
- from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Union
2
+ from typing import Any, AsyncGenerator, List, Optional, Sequence, Union
3
3
 
4
4
  from promptl_ai import Adapter, Message, MessageLike, Promptl, ToolMessage, ToolResultContent
5
5
  from promptl_ai.bindings.types import _Message
@@ -9,6 +9,7 @@ from latitude_sdk.client import (
9
9
  ChatPromptRequestParams,
10
10
  Client,
11
11
  ClientEvent,
12
+ GetAllPromptRequestParams,
12
13
  GetOrCreatePromptRequestBody,
13
14
  GetOrCreatePromptRequestParams,
14
15
  GetPromptRequestParams,
@@ -18,12 +19,8 @@ from latitude_sdk.client import (
18
19
  )
19
20
  from latitude_sdk.sdk.errors import ApiError, ApiErrorCodes
20
21
  from latitude_sdk.sdk.types import (
21
- ChainEventCompleted,
22
- ChainEventError,
23
22
  ChainEvents,
24
- ChainEventStep,
25
- ChainEventStepCompleted,
26
- FinishedEvent,
23
+ FinishedResult,
27
24
  OnStep,
28
25
  OnToolCall,
29
26
  OnToolCallDetails,
@@ -33,8 +30,11 @@ from latitude_sdk.sdk.types import (
33
30
  StreamCallbacks,
34
31
  StreamEvents,
35
32
  StreamTypes,
33
+ ToolCall,
36
34
  ToolResult,
35
+ _LatitudeEvent,
37
36
  )
37
+ from latitude_sdk.util import Adapter as AdapterUtil
38
38
  from latitude_sdk.util import Model
39
39
 
40
40
  _PROVIDER_TO_ADAPTER = {
@@ -69,6 +69,13 @@ class GetPromptResult(Prompt, Model):
69
69
  pass
70
70
 
71
71
 
72
+ _GetAllPromptResults = AdapterUtil[List[GetPromptResult]](List[GetPromptResult])
73
+
74
+
75
+ class GetAllPromptOptions(PromptOptions, Model):
76
+ pass
77
+
78
+
72
79
  class GetOrCreatePromptOptions(PromptOptions, Model):
73
80
  prompt: Optional[str] = None
74
81
 
@@ -79,36 +86,36 @@ class GetOrCreatePromptResult(Prompt, Model):
79
86
 
80
87
  class RunPromptOptions(StreamCallbacks, PromptOptions, Model):
81
88
  custom_identifier: Optional[str] = None
82
- parameters: Optional[Dict[str, Any]] = None
83
- tools: Optional[Dict[str, OnToolCall]] = None
89
+ parameters: Optional[dict[str, Any]] = None
90
+ tools: Optional[dict[str, OnToolCall]] = None
84
91
  stream: Optional[bool] = None
85
92
 
86
93
 
87
- class RunPromptResult(FinishedEvent, Model):
94
+ class RunPromptResult(FinishedResult, Model):
88
95
  pass
89
96
 
90
97
 
91
98
  class ChatPromptOptions(StreamCallbacks, Model):
92
- tools: Optional[Dict[str, OnToolCall]] = None
99
+ tools: Optional[dict[str, OnToolCall]] = None
93
100
  stream: Optional[bool] = None
94
101
 
95
102
 
96
- class ChatPromptResult(FinishedEvent, Model):
103
+ class ChatPromptResult(FinishedResult, Model):
97
104
  pass
98
105
 
99
106
 
100
107
  class RenderPromptOptions(Model):
101
- parameters: Optional[Dict[str, Any]] = None
108
+ parameters: Optional[dict[str, Any]] = None
102
109
  adapter: Optional[Adapter] = None
103
110
 
104
111
 
105
112
  class RenderPromptResult(Model):
106
- messages: List[MessageLike]
107
- config: Dict[str, Any]
113
+ messages: list[MessageLike]
114
+ config: dict[str, Any]
108
115
 
109
116
 
110
117
  class RenderChainOptions(Model):
111
- parameters: Optional[Dict[str, Any]] = None
118
+ parameters: Optional[dict[str, Any]] = None
112
119
  adapter: Optional[Adapter] = None
113
120
 
114
121
 
@@ -137,32 +144,27 @@ class Prompts:
137
144
 
138
145
  async def _handle_stream(
139
146
  self, stream: AsyncGenerator[ClientEvent, Any], on_event: Optional[StreamCallbacks.OnEvent]
140
- ) -> FinishedEvent:
147
+ ) -> FinishedResult:
141
148
  uuid = None
142
- conversation: List[Message] = []
149
+ conversation: list[Message] = []
143
150
  response = None
151
+ tool_requests: list[ToolCall] = []
144
152
 
145
153
  async for stream_event in stream:
146
154
  event = None
147
155
 
148
156
  if stream_event.event == str(StreamEvents.Latitude):
149
- type = stream_event.json().get("type")
150
-
151
- if type == str(ChainEvents.Step):
152
- event = ChainEventStep.model_validate_json(stream_event.data)
153
- conversation.extend(event.messages)
157
+ event = _LatitudeEvent.validate_json(stream_event.data)
158
+ conversation = event.messages
159
+ uuid = event.uuid
154
160
 
155
- elif type == str(ChainEvents.StepCompleted):
156
- event = ChainEventStepCompleted.model_validate_json(stream_event.data)
157
-
158
- elif type == str(ChainEvents.Completed):
159
- event = ChainEventCompleted.model_validate_json(stream_event.data)
160
- uuid = event.uuid
161
- conversation.extend(event.messages or [])
161
+ if event.type == ChainEvents.ProviderCompleted:
162
162
  response = event.response
163
163
 
164
- elif type == str(ChainEvents.Error):
165
- event = ChainEventError.model_validate_json(stream_event.data)
164
+ elif event.type == ChainEvents.ToolsRequested:
165
+ tool_requests = event.tools
166
+
167
+ elif event.type == ChainEvents.ChainError:
166
168
  raise ApiError(
167
169
  status=400,
168
170
  code=ApiErrorCodes.AIRunError,
@@ -170,14 +172,6 @@ class Prompts:
170
172
  response=stream_event.data,
171
173
  )
172
174
 
173
- else:
174
- raise ApiError(
175
- status=500,
176
- code=ApiErrorCodes.InternalServerError,
177
- message=f"Unknown latitude event: {type}",
178
- response=stream_event.data,
179
- )
180
-
181
175
  elif stream_event.event == str(StreamEvents.Provider):
182
176
  event = stream_event.json()
183
177
  event["event"] = StreamEvents.Provider
@@ -201,8 +195,7 @@ class Prompts:
201
195
  response="Stream ended without a chain-complete event. Missing uuid or response.",
202
196
  )
203
197
 
204
- # NOTE: FinishedEvent not in on_event
205
- return FinishedEvent(uuid=uuid, conversation=conversation, response=response)
198
+ return FinishedResult(uuid=uuid, conversation=conversation, response=response, tool_requests=tool_requests)
206
199
 
207
200
  @staticmethod
208
201
  def _pause_tool_execution() -> Any:
@@ -210,9 +203,9 @@ class Prompts:
210
203
 
211
204
  @staticmethod
212
205
  async def _wrap_tool_handler(
213
- handler: OnToolCall, arguments: Dict[str, Any], details: OnToolCallDetails
206
+ handler: OnToolCall, arguments: dict[str, Any], details: OnToolCallDetails
214
207
  ) -> ToolResult:
215
- tool_result: Dict[str, Any] = {"id": details.id, "name": details.name}
208
+ tool_result: dict[str, Any] = {"id": details.id, "name": details.name}
216
209
 
217
210
  try:
218
211
  result = await handler(arguments, details)
@@ -225,10 +218,10 @@ class Prompts:
225
218
  return ToolResult(**tool_result, result=str(exception), is_error=True)
226
219
 
227
220
  async def _handle_tool_calls(
228
- self, result: FinishedEvent, options: Union[RunPromptOptions, ChatPromptOptions]
229
- ) -> Optional[FinishedEvent]:
221
+ self, result: FinishedResult, options: Union[RunPromptOptions, ChatPromptOptions]
222
+ ) -> Optional[FinishedResult]:
230
223
  # Seems Python cannot infer the type
231
- assert result.response.type == StreamTypes.Text and result.response.tool_calls is not None
224
+ assert result.response.type == StreamTypes.Text and result.tool_requests is not None
232
225
 
233
226
  if not options.tools:
234
227
  raise ApiError(
@@ -238,7 +231,7 @@ class Prompts:
238
231
  response="Tools not supplied",
239
232
  )
240
233
 
241
- for tool_call in result.response.tool_calls:
234
+ for tool_call in result.tool_requests:
242
235
  if tool_call.name not in options.tools:
243
236
  raise ApiError(
244
237
  status=400,
@@ -258,10 +251,10 @@ class Prompts:
258
251
  conversation_uuid=result.uuid,
259
252
  messages=result.conversation,
260
253
  pause_execution=self._pause_tool_execution,
261
- requested_tool_calls=result.response.tool_calls,
254
+ requested_tool_calls=result.tool_requests,
262
255
  ),
263
256
  )
264
- for tool_call in result.response.tool_calls
257
+ for tool_call in result.tool_requests
265
258
  ],
266
259
  return_exceptions=False,
267
260
  )
@@ -282,7 +275,7 @@ class Prompts:
282
275
 
283
276
  next_result = await self.chat(result.uuid, tool_messages, ChatPromptOptions(**dict(options)))
284
277
 
285
- return FinishedEvent(**dict(next_result)) if next_result else None
278
+ return FinishedResult(**dict(next_result)) if next_result else None
286
279
 
287
280
  async def get(self, path: str, options: Optional[GetPromptOptions] = None) -> GetPromptResult:
288
281
  options = GetPromptOptions(**{**dict(self._options), **dict(options or {})})
@@ -299,6 +292,20 @@ class Prompts:
299
292
  ) as response:
300
293
  return GetPromptResult.model_validate_json(response.content)
301
294
 
295
+ async def get_all(self, options: Optional[GetAllPromptOptions] = None) -> List[GetPromptResult]:
296
+ options = GetAllPromptOptions(**{**dict(self._options), **dict(options or {})})
297
+ self._ensure_prompt_options(options)
298
+ assert options.project_id is not None
299
+
300
+ async with self._client.request(
301
+ handler=RequestHandler.GetAllPrompts,
302
+ params=GetAllPromptRequestParams(
303
+ project_id=options.project_id,
304
+ version_uuid=options.version_uuid,
305
+ ),
306
+ ) as response:
307
+ return _GetAllPromptResults.validate_json(response.content)
308
+
302
309
  async def get_or_create(
303
310
  self, path: str, options: Optional[GetOrCreatePromptOptions] = None
304
311
  ) -> GetOrCreatePromptResult:
@@ -343,7 +350,7 @@ class Prompts:
343
350
  else:
344
351
  result = RunPromptResult.model_validate_json(response.content)
345
352
 
346
- if options.tools and result.response.type == StreamTypes.Text and result.response.tool_calls:
353
+ if options.tools and result.response.type == StreamTypes.Text and result.tool_requests:
347
354
  try:
348
355
  # NOTE: The last sdk.chat called will already call on_finished
349
356
  final_result = await self._handle_tool_calls(result, options)
@@ -352,7 +359,7 @@ class Prompts:
352
359
  pass
353
360
 
354
361
  if options.on_finished:
355
- options.on_finished(FinishedEvent(**dict(result)))
362
+ options.on_finished(FinishedResult(**dict(result)))
356
363
 
357
364
  return RunPromptResult(**dict(result))
358
365
 
@@ -395,7 +402,7 @@ class Prompts:
395
402
  else:
396
403
  result = ChatPromptResult.model_validate_json(response.content)
397
404
 
398
- if options.tools and result.response.type == StreamTypes.Text and result.response.tool_calls:
405
+ if options.tools and result.response.type == StreamTypes.Text and result.tool_requests:
399
406
  try:
400
407
  # NOTE: The last sdk.chat called will already call on_finished
401
408
  final_result = await self._handle_tool_calls(result, options)
@@ -404,7 +411,7 @@ class Prompts:
404
411
  pass
405
412
 
406
413
  if options.on_finished:
407
- options.on_finished(FinishedEvent(**dict(result)))
414
+ options.on_finished(FinishedResult(**dict(result)))
408
415
 
409
416
  return ChatPromptResult(**dict(result))
410
417
 
@@ -424,8 +431,8 @@ class Prompts:
424
431
 
425
432
  return None
426
433
 
427
- def _adapt_prompt_config(self, config: Dict[str, Any], adapter: Adapter) -> Dict[str, Any]:
428
- adapted_config: Dict[str, Any] = {}
434
+ def _adapt_prompt_config(self, config: dict[str, Any], adapter: Adapter) -> dict[str, Any]:
435
+ adapted_config: dict[str, Any] = {}
429
436
 
430
437
  # NOTE: Should we delete attributes not supported by the provider?
431
438
  for attr, value in config.items():
@@ -1,10 +1,10 @@
1
1
  from datetime import datetime
2
- from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Sequence, Union, runtime_checkable
2
+ from typing import Any, Callable, Literal, Optional, Protocol, Sequence, Union, runtime_checkable
3
3
 
4
4
  from promptl_ai import Message, MessageLike
5
5
 
6
6
  from latitude_sdk.sdk.errors import ApiError
7
- from latitude_sdk.util import Field, Model, StrEnum
7
+ from latitude_sdk.util import Adapter, Field, Model, StrEnum
8
8
 
9
9
 
10
10
  class DbErrorRef(Model):
@@ -19,14 +19,27 @@ class Providers(StrEnum):
19
19
  Mistral = "mistral"
20
20
  Azure = "azure"
21
21
  Google = "google"
22
+ GoogleVertex = "google_vertex"
23
+ AnthropicVertex = "anthropic_vertex"
22
24
  Custom = "custom"
23
25
 
24
26
 
27
+ class ParameterType(StrEnum):
28
+ Text = "text"
29
+ File = "file"
30
+ Image = "image"
31
+
32
+
33
+ class PromptParameter(Model):
34
+ type: ParameterType
35
+
36
+
25
37
  class Prompt(Model):
26
38
  uuid: str
27
39
  path: str
28
40
  content: str
29
- config: Dict[str, Any]
41
+ config: dict[str, Any]
42
+ parameters: dict[str, PromptParameter]
30
43
  provider: Optional[Providers] = None
31
44
 
32
45
 
@@ -49,7 +62,7 @@ class FinishReason(StrEnum):
49
62
  class ToolCall(Model):
50
63
  id: str
51
64
  name: str
52
- arguments: Dict[str, Any]
65
+ arguments: dict[str, Any]
53
66
 
54
67
 
55
68
  class ToolResult(Model):
@@ -67,7 +80,7 @@ class StreamTypes(StrEnum):
67
80
  class ChainTextResponse(Model):
68
81
  type: Literal[StreamTypes.Text] = Field(default=StreamTypes.Text, alias=str("streamType"))
69
82
  text: str
70
- tool_calls: Optional[List[ToolCall]] = Field(default=None, alias=str("toolCalls"))
83
+ tool_calls: list[ToolCall] = Field(alias=str("toolCalls"))
71
84
  usage: ModelUsage
72
85
 
73
86
 
@@ -89,66 +102,105 @@ class ChainError(Model):
89
102
  class StreamEvents(StrEnum):
90
103
  Latitude = "latitude-event"
91
104
  Provider = "provider-event"
92
- Finished = "finished-event"
93
105
 
94
106
 
95
- ProviderEvent = Dict[str, Any]
107
+ ProviderEvent = dict[str, Any]
96
108
 
97
109
 
98
110
  class ChainEvents(StrEnum):
99
- Step = "chain-step"
100
- StepCompleted = "chain-step-complete"
101
- Completed = "chain-complete"
102
- Error = "chain-error"
111
+ ChainStarted = "chain-started"
112
+ StepStarted = "step-started"
113
+ ProviderStarted = "provider-started"
114
+ ProviderCompleted = "provider-completed"
115
+ ToolsStarted = "tools-started"
116
+ ToolCompleted = "tool-completed"
117
+ StepCompleted = "step-completed"
118
+ ChainCompleted = "chain-completed"
119
+ ChainError = "chain-error"
120
+ ToolsRequested = "tools-requested"
121
+
122
+
123
+ class GenericChainEvent(Model):
124
+ event: Literal[StreamEvents.Latitude] = StreamEvents.Latitude
125
+ messages: list[Message]
126
+ uuid: str
103
127
 
104
128
 
105
- class ChainEventStep(Model):
106
- event: Literal[StreamEvents.Latitude] = StreamEvents.Latitude
107
- type: Literal[ChainEvents.Step] = ChainEvents.Step
108
- uuid: Optional[str] = None
109
- is_last_step: bool = Field(alias=str("isLastStep"))
110
- config: Dict[str, Any]
111
- messages: List[Message]
129
+ class ChainEventChainStarted(GenericChainEvent):
130
+ type: Literal[ChainEvents.ChainStarted] = ChainEvents.ChainStarted
112
131
 
113
132
 
114
- class ChainEventStepCompleted(Model):
115
- event: Literal[StreamEvents.Latitude] = StreamEvents.Latitude
116
- type: Literal[ChainEvents.StepCompleted] = ChainEvents.StepCompleted
117
- uuid: Optional[str] = None
118
- response: ChainResponse
133
+ class ChainEventStepStarted(GenericChainEvent):
134
+ type: Literal[ChainEvents.StepStarted] = ChainEvents.StepStarted
119
135
 
120
136
 
121
- class ChainEventCompleted(Model):
122
- event: Literal[StreamEvents.Latitude] = StreamEvents.Latitude
123
- type: Literal[ChainEvents.Completed] = ChainEvents.Completed
124
- uuid: Optional[str] = None
137
+ class ChainEventProviderStarted(GenericChainEvent):
138
+ type: Literal[ChainEvents.ProviderStarted] = ChainEvents.ProviderStarted
139
+ config: dict[str, Any]
140
+
141
+
142
+ class ChainEventProviderCompleted(GenericChainEvent):
143
+ type: Literal[ChainEvents.ProviderCompleted] = ChainEvents.ProviderCompleted
144
+ provider_log_uuid: str = Field(alias=str("providerLogUuid"))
145
+ token_usage: ModelUsage = Field(alias=str("tokenUsage"))
125
146
  finish_reason: FinishReason = Field(alias=str("finishReason"))
126
- config: Dict[str, Any]
127
- messages: Optional[List[Message]] = None
128
- object: Optional[Any] = None
129
147
  response: ChainResponse
130
148
 
131
149
 
132
- class ChainEventError(Model):
133
- event: Literal[StreamEvents.Latitude] = StreamEvents.Latitude
134
- type: Literal[ChainEvents.Error] = ChainEvents.Error
150
+ class ChainEventToolsStarted(GenericChainEvent):
151
+ type: Literal[ChainEvents.ToolsStarted] = ChainEvents.ToolsStarted
152
+ tools: list[ToolCall]
153
+
154
+
155
+ class ChainEventToolCompleted(GenericChainEvent):
156
+ type: Literal[ChainEvents.ToolCompleted] = ChainEvents.ToolCompleted
157
+
158
+
159
+ class ChainEventStepCompleted(GenericChainEvent):
160
+ type: Literal[ChainEvents.StepCompleted] = ChainEvents.StepCompleted
161
+
162
+
163
+ class ChainEventChainCompleted(GenericChainEvent):
164
+ type: Literal[ChainEvents.ChainCompleted] = ChainEvents.ChainCompleted
165
+ token_usage: ModelUsage = Field(alias=str("tokenUsage"))
166
+ finish_reason: FinishReason = Field(alias=str("finishReason"))
167
+
168
+
169
+ class ChainEventChainError(GenericChainEvent):
170
+ type: Literal[ChainEvents.ChainError] = ChainEvents.ChainError
135
171
  error: ChainError
136
172
 
137
173
 
138
- ChainEvent = Union[ChainEventStep, ChainEventStepCompleted, ChainEventCompleted, ChainEventError]
174
+ class ChainEventToolsRequested(GenericChainEvent):
175
+ type: Literal[ChainEvents.ToolsRequested] = ChainEvents.ToolsRequested
176
+ tools: list[ToolCall]
177
+
139
178
 
179
+ ChainEvent = Union[
180
+ ChainEventChainStarted,
181
+ ChainEventStepStarted,
182
+ ChainEventProviderStarted,
183
+ ChainEventProviderCompleted,
184
+ ChainEventToolsStarted,
185
+ ChainEventToolCompleted,
186
+ ChainEventStepCompleted,
187
+ ChainEventChainCompleted,
188
+ ChainEventChainError,
189
+ ChainEventToolsRequested,
190
+ ]
140
191
 
141
192
  LatitudeEvent = ChainEvent
193
+ _LatitudeEvent = Adapter[LatitudeEvent](LatitudeEvent)
142
194
 
143
195
 
144
- class FinishedEvent(Model):
145
- event: Literal[StreamEvents.Finished] = StreamEvents.Finished
196
+ class FinishedResult(Model):
146
197
  uuid: str
147
- conversation: List[Message]
198
+ conversation: list[Message]
148
199
  response: ChainResponse
200
+ tool_requests: list[ToolCall] = Field(alias=str("toolRequests"))
149
201
 
150
202
 
151
- StreamEvent = Union[ProviderEvent, LatitudeEvent, FinishedEvent]
203
+ StreamEvent = Union[ProviderEvent, LatitudeEvent]
152
204
 
153
205
 
154
206
  class LogSources(StrEnum):
@@ -166,7 +218,7 @@ class Log(Model):
166
218
  commit_id: int = Field(alias=str("commitId"))
167
219
  resolved_content: str = Field(alias=str("resolvedContent"))
168
220
  content_hash: str = Field(alias=str("contentHash"))
169
- parameters: Dict[str, Any]
221
+ parameters: dict[str, Any]
170
222
  custom_identifier: Optional[str] = Field(default=None, alias=str("customIdentifier"))
171
223
  duration: Optional[int] = None
172
224
  created_at: datetime = Field(alias=str("createdAt"))
@@ -204,7 +256,7 @@ class StreamCallbacks(Model):
204
256
 
205
257
  @runtime_checkable
206
258
  class OnFinished(Protocol):
207
- def __call__(self, event: FinishedEvent): ...
259
+ def __call__(self, result: FinishedResult): ...
208
260
 
209
261
  on_finished: Optional[OnFinished] = None
210
262
 
@@ -219,27 +271,27 @@ class OnToolCallDetails(Model):
219
271
  id: str
220
272
  name: str
221
273
  conversation_uuid: str
222
- messages: List[Message]
274
+ messages: list[Message]
223
275
  pause_execution: Callable[[], ToolResult]
224
- requested_tool_calls: List[ToolCall]
276
+ requested_tool_calls: list[ToolCall]
225
277
 
226
278
 
227
279
  @runtime_checkable
228
280
  class OnToolCall(Protocol):
229
- async def __call__(self, arguments: Dict[str, Any], details: OnToolCallDetails) -> Any: ...
281
+ async def __call__(self, arguments: dict[str, Any], details: OnToolCallDetails) -> Any: ...
230
282
 
231
283
 
232
284
  @runtime_checkable
233
285
  class OnStep(Protocol):
234
286
  async def __call__(
235
- self, messages: List[MessageLike], config: Dict[str, Any]
287
+ self, messages: list[MessageLike], config: dict[str, Any]
236
288
  ) -> Union[str, MessageLike, Sequence[MessageLike]]: ...
237
289
 
238
290
 
239
291
  class SdkOptions(Model):
240
292
  project_id: Optional[int] = None
241
293
  version_uuid: Optional[str] = None
242
- tools: Optional[Dict[str, OnToolCall]] = None
294
+ tools: Optional[dict[str, OnToolCall]] = None
243
295
 
244
296
 
245
297
  class GatewayOptions(Model):