latitude-sdk 0.1.0b8__py3-none-any.whl → 0.1.0b9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -41,6 +41,8 @@ class Evaluations:
41
41
  self._client = client
42
42
 
43
43
  async def trigger(self, uuid: str, options: TriggerEvaluationOptions) -> TriggerEvaluationResult:
44
+ options = TriggerEvaluationOptions(**{**dict(self._options), **dict(options)})
45
+
44
46
  async with self._client.request(
45
47
  handler=RequestHandler.TriggerEvaluation,
46
48
  params=TriggerEvaluationRequestParams(
@@ -55,6 +57,8 @@ class Evaluations:
55
57
  async def create_result(
56
58
  self, uuid: str, evaluation_uuid: str, options: CreateEvaluationResultOptions
57
59
  ) -> CreateEvaluationResultResult:
60
+ options = CreateEvaluationResultOptions(**{**dict(self._options), **dict(options)})
61
+
58
62
  async with self._client.request(
59
63
  handler=RequestHandler.CreateEvaluationResult,
60
64
  params=CreateEvaluationResultRequestParams(
@@ -39,7 +39,7 @@ DEFAULT_INTERNAL_OPTIONS = InternalOptions(
39
39
 
40
40
 
41
41
  DEFAULT_LATITUDE_OPTIONS = LatitudeOptions(
42
- telemetry=None, # Note: Telemetry is opt-in
42
+ telemetry=None, # NOTE: Telemetry is opt-in
43
43
  internal=DEFAULT_INTERNAL_OPTIONS,
44
44
  )
45
45
 
latitude_sdk/sdk/logs.py CHANGED
@@ -32,9 +32,8 @@ class Logs:
32
32
  self._options = options
33
33
  self._client = client
34
34
 
35
- def _ensure_options(self, options: LogOptions) -> LogOptions:
36
- project_id = options.project_id or self._options.project_id
37
- if not project_id:
35
+ def _ensure_log_options(self, options: LogOptions):
36
+ if not options.project_id:
38
37
  raise ApiError(
39
38
  status=404,
40
39
  code=ApiErrorCodes.NotFoundError,
@@ -42,16 +41,11 @@ class Logs:
42
41
  response="Project ID is required",
43
42
  )
44
43
 
45
- version_uuid = options.version_uuid or self._options.version_uuid
46
-
47
- return LogOptions(project_id=project_id, version_uuid=version_uuid)
48
-
49
44
  async def create(
50
45
  self, path: str, messages: Sequence[Union[Message, Dict[str, Any]]], options: CreateLogOptions
51
46
  ) -> CreateLogResult:
52
- log_options = self._ensure_options(options)
53
- options = CreateLogOptions(**{**dict(options), **dict(log_options)})
54
-
47
+ options = CreateLogOptions(**{**dict(self._options), **dict(options)})
48
+ self._ensure_log_options(options)
55
49
  assert options.project_id is not None
56
50
 
57
51
  messages = [_Message.validate_python(message) for message in messages]
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Union
2
3
 
3
4
  from latitude_sdk.client import (
@@ -21,15 +22,25 @@ from latitude_sdk.sdk.types import (
21
22
  ChainEventStepCompleted,
22
23
  FinishedEvent,
23
24
  Message,
25
+ OnToolCall,
26
+ OnToolCallDetails,
24
27
  Prompt,
25
28
  SdkOptions,
26
29
  StreamCallbacks,
27
30
  StreamEvents,
31
+ StreamTypes,
32
+ ToolMessage,
33
+ ToolResult,
34
+ ToolResultContent,
28
35
  _Message,
29
36
  )
30
37
  from latitude_sdk.util import Model
31
38
 
32
39
 
40
+ class OnToolCallPaused(Exception):
41
+ pass
42
+
43
+
33
44
  class PromptOptions(Model):
34
45
  project_id: Optional[int] = None
35
46
  version_uuid: Optional[str] = None
@@ -54,6 +65,7 @@ class GetOrCreatePromptResult(Prompt, Model):
54
65
  class RunPromptOptions(StreamCallbacks, PromptOptions, Model):
55
66
  custom_identifier: Optional[str] = None
56
67
  parameters: Optional[Dict[str, Any]] = None
68
+ tools: Optional[Dict[str, OnToolCall]] = None
57
69
  stream: Optional[bool] = None
58
70
 
59
71
 
@@ -62,6 +74,7 @@ class RunPromptResult(FinishedEvent, Model):
62
74
 
63
75
 
64
76
  class ChatPromptOptions(StreamCallbacks, Model):
77
+ tools: Optional[Dict[str, OnToolCall]] = None
65
78
  stream: Optional[bool] = None
66
79
 
67
80
 
@@ -77,9 +90,8 @@ class Prompts:
77
90
  self._options = options
78
91
  self._client = client
79
92
 
80
- def _ensure_options(self, options: PromptOptions) -> PromptOptions:
81
- project_id = options.project_id or self._options.project_id
82
- if not project_id:
93
+ def _ensure_prompt_options(self, options: PromptOptions):
94
+ if not options.project_id:
83
95
  raise ApiError(
84
96
  status=404,
85
97
  code=ApiErrorCodes.NotFoundError,
@@ -87,12 +99,8 @@ class Prompts:
87
99
  response="Project ID is required",
88
100
  )
89
101
 
90
- version_uuid = options.version_uuid or self._options.version_uuid
91
-
92
- return PromptOptions(project_id=project_id, version_uuid=version_uuid)
93
-
94
102
  async def _handle_stream(
95
- self, stream: AsyncGenerator[ClientEvent, Any], callbacks: StreamCallbacks
103
+ self, stream: AsyncGenerator[ClientEvent, Any], on_event: Optional[StreamCallbacks.OnEvent]
96
104
  ) -> FinishedEvent:
97
105
  uuid = None
98
106
  conversation: List[Message] = []
@@ -146,8 +154,8 @@ class Prompts:
146
154
  response=stream_event.data,
147
155
  )
148
156
 
149
- if callbacks.on_event:
150
- callbacks.on_event(event)
157
+ if on_event:
158
+ on_event(event)
151
159
 
152
160
  if not uuid or not response:
153
161
  raise ApiError(
@@ -160,10 +168,65 @@ class Prompts:
160
168
  # NOTE: FinishedEvent not in on_event
161
169
  return FinishedEvent(uuid=uuid, conversation=conversation, response=response)
162
170
 
163
- async def get(self, path: str, options: GetPromptOptions) -> GetPromptResult:
164
- prompt_options = self._ensure_options(options)
165
- options = GetPromptOptions(**{**dict(options), **dict(prompt_options)})
171
+ def _pause_tool_execution(self) -> ToolResult:
172
+ raise OnToolCallPaused()
173
+
174
+ async def _handle_tool_calls(
175
+ self, result: FinishedEvent, options: Union[RunPromptOptions, ChatPromptOptions]
176
+ ) -> Optional[FinishedEvent]:
177
+ # Seems Python cannot infer the type
178
+ assert result.response.type == StreamTypes.Text and result.response.tool_calls is not None
179
+
180
+ if not options.tools:
181
+ raise ApiError(
182
+ status=400,
183
+ code=ApiErrorCodes.AIRunError,
184
+ message="Tools not supplied",
185
+ response="Tools not supplied",
186
+ )
187
+
188
+ for tool_call in result.response.tool_calls:
189
+ if tool_call.name not in options.tools:
190
+ raise ApiError(
191
+ status=400,
192
+ code=ApiErrorCodes.AIRunError,
193
+ message=f"Tool {tool_call.name} not supplied",
194
+ response=f"Tool {tool_call.name} not supplied",
195
+ )
166
196
 
197
+ details = OnToolCallDetails(
198
+ conversation_uuid=result.uuid,
199
+ messages=result.conversation,
200
+ pause_execution=self._pause_tool_execution,
201
+ requested_tool_calls=result.response.tool_calls,
202
+ )
203
+
204
+ tool_results = await asyncio.gather(
205
+ *[options.tools[tool_call.name](tool_call, details) for tool_call in result.response.tool_calls],
206
+ return_exceptions=False,
207
+ )
208
+
209
+ tool_messages = [
210
+ ToolMessage(
211
+ content=[
212
+ ToolResultContent(
213
+ id=tool_result.id,
214
+ name=tool_result.name,
215
+ result=tool_result.result,
216
+ is_error=tool_result.is_error,
217
+ )
218
+ ]
219
+ )
220
+ for tool_result in tool_results
221
+ ]
222
+
223
+ next_result = await self.chat(result.uuid, tool_messages, ChatPromptOptions(**dict(options)))
224
+
225
+ return FinishedEvent(**dict(next_result)) if next_result else None
226
+
227
+ async def get(self, path: str, options: GetPromptOptions) -> GetPromptResult:
228
+ options = GetPromptOptions(**{**dict(self._options), **dict(options)})
229
+ self._ensure_prompt_options(options)
167
230
  assert options.project_id is not None
168
231
 
169
232
  async with self._client.request(
@@ -177,9 +240,8 @@ class Prompts:
177
240
  return GetPromptResult.model_validate_json(response.content)
178
241
 
179
242
  async def get_or_create(self, path: str, options: GetOrCreatePromptOptions) -> GetOrCreatePromptResult:
180
- prompt_options = self._ensure_options(options)
181
- options = GetOrCreatePromptOptions(**{**dict(options), **dict(prompt_options)})
182
-
243
+ options = GetOrCreatePromptOptions(**{**dict(self._options), **dict(options)})
244
+ self._ensure_prompt_options(options)
183
245
  assert options.project_id is not None
184
246
 
185
247
  async with self._client.request(
@@ -197,9 +259,8 @@ class Prompts:
197
259
 
198
260
  async def run(self, path: str, options: RunPromptOptions) -> Optional[RunPromptResult]:
199
261
  try:
200
- prompt_options = self._ensure_options(options)
201
- options = RunPromptOptions(**{**dict(options), **dict(prompt_options)})
202
-
262
+ options = RunPromptOptions(**{**dict(self._options), **dict(options)})
263
+ self._ensure_prompt_options(options)
203
264
  assert options.project_id is not None
204
265
 
205
266
  async with self._client.request(
@@ -216,21 +277,22 @@ class Prompts:
216
277
  ),
217
278
  ) as response:
218
279
  if options.stream:
219
- result = await self._handle_stream(
220
- response.sse(),
221
- callbacks=StreamCallbacks(
222
- on_event=options.on_event,
223
- on_finished=options.on_finished,
224
- on_error=options.on_error,
225
- ),
226
- )
280
+ result = await self._handle_stream(response.sse(), options.on_event)
227
281
  else:
228
282
  result = RunPromptResult.model_validate_json(response.content)
229
283
 
230
- if options.on_finished:
231
- options.on_finished(FinishedEvent(**dict(result)))
284
+ if options.tools and result.response.type == StreamTypes.Text and result.response.tool_calls:
285
+ try:
286
+ # NOTE: The last sdk.chat called will already call on_finished
287
+ final_result = await self._handle_tool_calls(result, options)
288
+ return RunPromptResult(**dict(final_result)) if final_result else None
289
+ except OnToolCallPaused:
290
+ pass
291
+
292
+ if options.on_finished:
293
+ options.on_finished(FinishedEvent(**dict(result)))
232
294
 
233
- return RunPromptResult(**dict(result))
295
+ return RunPromptResult(**dict(result))
234
296
 
235
297
  except Exception as exception:
236
298
  if not isinstance(exception, ApiError):
@@ -252,6 +314,8 @@ class Prompts:
252
314
  self, uuid: str, messages: Sequence[Union[Message, Dict[str, Any]]], options: ChatPromptOptions
253
315
  ) -> Optional[ChatPromptResult]:
254
316
  try:
317
+ options = ChatPromptOptions(**{**dict(self._options), **dict(options)})
318
+
255
319
  messages = [_Message.validate_python(message) for message in messages]
256
320
 
257
321
  async with self._client.request(
@@ -265,21 +329,22 @@ class Prompts:
265
329
  ),
266
330
  ) as response:
267
331
  if options.stream:
268
- result = await self._handle_stream(
269
- response.sse(),
270
- callbacks=StreamCallbacks(
271
- on_event=options.on_event,
272
- on_finished=options.on_finished,
273
- on_error=options.on_error,
274
- ),
275
- )
332
+ result = await self._handle_stream(response.sse(), options.on_event)
276
333
  else:
277
334
  result = ChatPromptResult.model_validate_json(response.content)
278
335
 
279
- if options.on_finished:
280
- options.on_finished(FinishedEvent(**dict(result)))
336
+ if options.tools and result.response.type == StreamTypes.Text and result.response.tool_calls:
337
+ try:
338
+ # NOTE: The last sdk.chat called will already call on_finished
339
+ final_result = await self._handle_tool_calls(result, options)
340
+ return ChatPromptResult(**dict(final_result)) if final_result else None
341
+ except OnToolCallPaused:
342
+ pass
343
+
344
+ if options.on_finished:
345
+ options.on_finished(FinishedEvent(**dict(result)))
281
346
 
282
- return ChatPromptResult(**dict(result))
347
+ return ChatPromptResult(**dict(result))
283
348
 
284
349
  except Exception as exception:
285
350
  if not isinstance(exception, ApiError):
latitude_sdk/sdk/types.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from datetime import datetime
2
- from typing import Any, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable
2
+ from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable
3
3
 
4
4
  from latitude_sdk.sdk.errors import ApiError
5
5
  from latitude_sdk.util import Adapter, Field, Model, StrEnum
@@ -63,7 +63,7 @@ class ToolResultContent(Model):
63
63
  type: Literal[ContentType.ToolResult] = ContentType.ToolResult
64
64
  id: str = Field(alias=str("toolCallId"))
65
65
  name: str = Field(alias=str("toolName"))
66
- result: str
66
+ result: Any
67
67
  is_error: Optional[bool] = Field(default=None, alias=str("isError"))
68
68
 
69
69
 
@@ -115,12 +115,29 @@ class ModelUsage(Model):
115
115
  total_tokens: int = Field(alias=str("totalTokens"))
116
116
 
117
117
 
118
+ class FinishReason(StrEnum):
119
+ Stop = "stop"
120
+ Length = "length"
121
+ ContentFilter = "content-filter"
122
+ ToolCalls = "tool-calls"
123
+ Error = "error"
124
+ Other = "other"
125
+ Unknown = "unknown"
126
+
127
+
118
128
  class ToolCall(Model):
119
129
  id: str
120
130
  name: str
121
131
  arguments: Dict[str, Any]
122
132
 
123
133
 
134
+ class ToolResult(Model):
135
+ id: str
136
+ name: str
137
+ result: Any
138
+ is_error: Optional[bool] = None
139
+
140
+
124
141
  class StreamTypes(StrEnum):
125
142
  Text = "text"
126
143
  Object = "object"
@@ -184,6 +201,7 @@ class ChainEventCompleted(Model):
184
201
  event: Literal[StreamEvents.Latitude] = StreamEvents.Latitude
185
202
  type: Literal[ChainEvents.Completed] = ChainEvents.Completed
186
203
  uuid: Optional[str] = None
204
+ finish_reason: FinishReason = Field(alias=str("finishReason"))
187
205
  config: Dict[str, Any]
188
206
  messages: Optional[List[Message]] = None
189
207
  object: Optional[Any] = None
@@ -276,9 +294,22 @@ class StreamCallbacks(Model):
276
294
  on_error: Optional[OnError] = None
277
295
 
278
296
 
297
+ class OnToolCallDetails(Model):
298
+ conversation_uuid: str
299
+ messages: List[Message]
300
+ pause_execution: Callable[[], ToolResult]
301
+ requested_tool_calls: List[ToolCall]
302
+
303
+
304
+ @runtime_checkable
305
+ class OnToolCall(Protocol):
306
+ async def __call__(self, call: ToolCall, details: OnToolCallDetails) -> ToolResult: ...
307
+
308
+
279
309
  class SdkOptions(Model):
280
310
  project_id: Optional[int] = None
281
311
  version_uuid: Optional[str] = None
312
+ tools: Optional[Dict[str, OnToolCall]] = None
282
313
 
283
314
 
284
315
  class GatewayOptions(Model):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: latitude-sdk
3
- Version: 0.1.0b8
3
+ Version: 0.1.0b9
4
4
  Summary: Latitude SDK for Python
5
5
  Project-URL: repository, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python
6
6
  Project-URL: homepage, https://github.com/latitude-dev/latitude-llm/tree/main/packages/sdks/python#readme
@@ -8,13 +8,13 @@ latitude_sdk/env/__init__.py,sha256=66of5veJ-u1aNI025L65Rrj321AjrYevMqomTMYIrPQ,
8
8
  latitude_sdk/env/env.py,sha256=MnXexPOHE6aXcAszrDCbW7hzACUv4YtU1bfxpYwvHNw,455
9
9
  latitude_sdk/sdk/__init__.py,sha256=C9LlIjfnrS7KOK3-ruXKmbT77nSQMm23nZ6-t8sO8ME,137
10
10
  latitude_sdk/sdk/errors.py,sha256=9GlGdDE8LGy3dE2Ry_BipBg-tDbQx7LWXJfSnTJSSBE,1747
11
- latitude_sdk/sdk/evaluations.py,sha256=ASWfNfH124qeahzhAn-gb2Ep4QIew5uDveY5NbNsNfk,2086
12
- latitude_sdk/sdk/latitude.py,sha256=GdkwdxtTSmfldLo17pJme8MWpjuHQy7JFDRYLg0l4cg,2724
13
- latitude_sdk/sdk/logs.py,sha256=-_jYlxKW8Hgq1nZ7QtNbaEg1eWuMUbeeA32jtFsDvNk,2199
14
- latitude_sdk/sdk/prompts.py,sha256=vh2WYSEQbArnqadfJ9mBwtpmEBzEJObUkNB6O87S9QQ,10256
15
- latitude_sdk/sdk/types.py,sha256=RPJA3cM8tJ6udwS4gi0LLlJPPrS10mSb6q4OaKmNsU4,7903
11
+ latitude_sdk/sdk/evaluations.py,sha256=xmlFtnFxDtTfO4cJnnh6ExFnCQHan_b25KrH-I9MW6I,2267
12
+ latitude_sdk/sdk/latitude.py,sha256=XoSsM2_v3s_ndaMIeIrbWQT9f-1W1fAu6P75Zrw1iMQ,2724
13
+ latitude_sdk/sdk/logs.py,sha256=cvLFW4xNaJ5XHS9W2YgK18NC6c7N1oCo61Qe6CqXxag,1968
14
+ latitude_sdk/sdk/prompts.py,sha256=Aez9Y9IP5rKyQ3zeTs1zmbXwRatQrvsyNxaTcoTndQ8,12726
15
+ latitude_sdk/sdk/types.py,sha256=Mvu4cUOhIa8v6j0t91m5M1mD9X9Qr_ZWEtvgv7IZGas,8653
16
16
  latitude_sdk/util/__init__.py,sha256=alIDGBnxWH4JvP-UW-7N99seBBi0r1GV1h8f1ERFBec,21
17
17
  latitude_sdk/util/utils.py,sha256=06phYKGKnlO0WcU3coMXpqyrHOVDvE0mFzvqTRJGbD8,2916
18
- latitude_sdk-0.1.0b8.dist-info/METADATA,sha256=Lsolipk18n_fM7hvfIXaTDuJZ7rOQLXZIr_fWmq3eKw,2031
19
- latitude_sdk-0.1.0b8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
20
- latitude_sdk-0.1.0b8.dist-info/RECORD,,
18
+ latitude_sdk-0.1.0b9.dist-info/METADATA,sha256=FnKMDcMpee49w2lB8MT3mOBJiEUz3lcQ2T16tb69hZU,2031
19
+ latitude_sdk-0.1.0b9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
20
+ latitude_sdk-0.1.0b9.dist-info/RECORD,,