vectorvein 0.1.23__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,11 +2,13 @@
  # @Date: 2024-07-26 14:48:55
  import json
  import random
+ from functools import cached_property
+ from typing import overload, Generator, AsyncGenerator, Any, Literal, Iterable

  import httpx
  from openai._types import NotGiven as OpenAINotGiven
  from anthropic import Anthropic, AnthropicVertex, AsyncAnthropic, AsyncAnthropicVertex
- from anthropic._types import NotGiven, NOT_GIVEN
+ from anthropic._types import NOT_GIVEN
  from anthropic.types import (
      TextBlock,
      ToolUseBlock,
@@ -24,15 +26,25 @@ from ..types import defaults as defs
  from .utils import cutoff_messages, get_message_token_counts
  from .base_client import BaseChatClient, BaseAsyncChatClient
  from ..types.enums import ContextLengthControlType, BackendType
- from ..types.llm_parameters import ChatCompletionMessage, ChatCompletionDeltaMessage
+ from ..types.llm_parameters import (
+     Usage,
+     NotGiven,
+     ToolParam,
+     ToolChoice,
+     AnthropicToolParam,
+     AnthropicToolChoice,
+     ChatCompletionMessage,
+     ChatCompletionToolParam,
+     ChatCompletionDeltaMessage,
+ )


- def refactor_tool_use_params(tools: list):
+ def refactor_tool_use_params(tools: Iterable[ChatCompletionToolParam]) -> list[AnthropicToolParam]:
      return [
          {
              "name": tool["function"]["name"],
-             "description": tool["function"]["description"],
-             "input_schema": tool["function"]["parameters"],
+             "description": tool["function"].get("description", ""),
+             "input_schema": tool["function"].get("parameters", {}),
          }
          for tool in tools
      ]
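
A note on the `refactor_tool_use_params` change: it converts OpenAI-style function tools into Anthropic's tool schema, and the new `.get(...)` fallbacks keep a tool that omits `description` or `parameters` from raising `KeyError`. A minimal sketch, with an invented weather tool for illustration:

```python
# Hypothetical OpenAI-style tool definition (invented for illustration).
openai_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# refactor_tool_use_params([openai_tool]) produces the Anthropic shape:
# [{"name": "get_weather",
#   "description": "Look up the current weather for a city.",
#   "input_schema": {"type": "object", ...}}]
# With 0.1.24, a tool missing "description" or "parameters" now falls
# back to "" / {} instead of raising KeyError.
```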
@@ -53,6 +65,17 @@ def refactor_tool_calls(tool_calls: list):
      ]


+ def refactor_tool_choice(tool_choice: ToolChoice) -> AnthropicToolChoice:
+     if isinstance(tool_choice, str):
+         if tool_choice == "auto":
+             return {"type": "auto"}
+         elif tool_choice == "required":
+             return {"type": "any"}
+     elif isinstance(tool_choice, dict) and "function" in tool_choice:
+         return {"type": "tool", "name": tool_choice["function"]["name"]}
+     return {"type": "auto"}
+
+
  def format_messages_alternate(messages: list) -> list:
      # messages: roles must alternate between "user" and "assistant", and not multiple "user" roles in a row
      # reformat multiple "user" roles in a row into {"role": "user", "content": [{"type": "text", "text": "Hello, Claude"}, {"type": "text", "text": "How are you?"}]}
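
The new `refactor_tool_choice` maps OpenAI-style `tool_choice` values onto Anthropic's scheme; note that OpenAI's `"required"` corresponds to Anthropic's `{"type": "any"}`. Its behavior, roughly (sample values are illustrative):

```python
refactor_tool_choice("auto")      # -> {"type": "auto"}
refactor_tool_choice("required")  # -> {"type": "any"}
refactor_tool_choice({"type": "function", "function": {"name": "get_weather"}})
#                                 # -> {"type": "tool", "name": "get_weather"}
refactor_tool_choice("none")      # unrecognized: falls back to {"type": "auto"}
```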
@@ -87,7 +110,7 @@ def format_messages_alternate(messages: list) -> list:


  class AnthropicChatClient(BaseChatClient):
-     DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+     DEFAULT_MODEL: str | None = defs.ANTHROPIC_DEFAULT_MODEL
      BACKEND_NAME: BackendType = BackendType.Anthropic

      def __init__(
@@ -112,7 +135,7 @@ class AnthropicChatClient(BaseChatClient):
              **kwargs,
          )

-     @property
+     @cached_property
      def raw_client(self):
          if self.random_endpoint:
              self.random_endpoint = True
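
Switching `raw_client` from `@property` to `@cached_property` means the endpoint is resolved and the SDK client constructed once per client instance instead of on every attribute access. A self-contained sketch of the difference:

```python
from functools import cached_property

class Demo:
    @cached_property
    def raw_client(self):
        print("constructing client")  # runs only on first access
        return object()

d = Demo()
assert d.raw_client is d.raw_client  # prints once; later reads hit the cache
```

One side effect worth noting: with `random_endpoint`, the endpoint is now picked on first access and then pinned for the life of the instance.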
@@ -120,6 +143,8 @@ class AnthropicChatClient(BaseChatClient):
              self.endpoint = settings.get_endpoint(self.endpoint_id)

          if self.endpoint.is_vertex:
+             if self.endpoint.credentials is None:
+                 raise ValueError("Anthropic Vertex endpoint requires credentials")
              self.creds = Credentials(
                  token=self.endpoint.credentials.get("token"),
                  refresh_token=self.endpoint.credentials.get("refresh_token"),
@@ -131,7 +156,7 @@ class AnthropicChatClient(BaseChatClient):
                  expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
                  rapt_token=self.endpoint.credentials.get("rapt_token"),
                  trust_boundary=self.endpoint.credentials.get("trust_boundary"),
-                 universe_domain=self.endpoint.credentials.get("universe_domain"),
+                 universe_domain=self.endpoint.credentials.get("universe_domain", "googleapis.com"),
                  account=self.endpoint.credentials.get("account", ""),
              )

@@ -143,10 +168,11 @@ class AnthropicChatClient(BaseChatClient):
          else:
              base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"

+         region = NOT_GIVEN if self.endpoint.region is None else self.endpoint.region
          return AnthropicVertex(
-             region=self.endpoint.region,
+             region=region,
              base_url=base_url,
-             project_id=self.endpoint.credentials.get("quota_project_id"),
+             project_id=self.endpoint.credentials.get("quota_project_id", NOT_GIVEN),
              access_token=self.creds.token,
              http_client=self.http_client,
          )
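
Several of these changes lean on the SDK's `NOT_GIVEN` sentinel, which means "omit this argument entirely" rather than "send `null`", so the SDK can apply its own defaults. Roughly (`endpoint_region` is a hypothetical stand-in):

```python
from anthropic._types import NOT_GIVEN

endpoint_region = None  # hypothetical value read from endpoint settings
region = NOT_GIVEN if endpoint_region is None else endpoint_region
# Passing region=NOT_GIVEN behaves as if the argument were never supplied,
# whereas region=None would send an explicit (and invalid) value.
```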
@@ -157,15 +183,46 @@ class AnthropicChatClient(BaseChatClient):
              http_client=self.http_client,
          )

+     @overload
+     def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[False] = False,
+         temperature: float | None = None,
+         max_tokens: int | None = None,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> ChatCompletionMessage:
+         pass
+
+     @overload
+     def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[True] = True,
+         temperature: float | None = None,
+         max_tokens: int | None = None,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> Generator[ChatCompletionDeltaMessage, None, None]:
+         pass
+
      def create_completion(
          self,
-         messages: list = list,
+         messages: list,
          model: str | None = None,
          stream: bool | None = None,
          temperature: float | None = None,
          max_tokens: int | None = None,
-         tools: list | NotGiven = NOT_GIVEN,
-         tool_choice: str | NotGiven = NOT_GIVEN,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
          **kwargs,
      ):
          if model is not None:
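
The paired `@overload` declarations tie the `stream` flag's literal value to the return type, so type checkers resolve `create_completion(..., stream=False)` to a `ChatCompletionMessage` and `stream=True` to a generator of deltas. (This also fixes the old `messages: list = list` default, which bound the `list` type itself as the default value.) The pattern in isolation, as a runnable sketch:

```python
from typing import Generator, Literal, overload

@overload
def fetch(stream: Literal[False] = False) -> str: ...
@overload
def fetch(stream: Literal[True] = True) -> Generator[str, None, None]: ...

def fetch(stream: bool = False):
    if stream:
        return (part for part in ("Hel", "lo"))  # generator branch
    return "Hello"                               # plain-value branch

reply = fetch(stream=False)   # checker infers: str
chunks = fetch(stream=True)   # checker infers: Generator[str, None, None]
```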
@@ -182,7 +239,7 @@ class AnthropicChatClient(BaseChatClient):
          self.model_setting = self.backend_settings.models[self.model]

          if messages[0].get("role") == "system":
-             system_prompt = messages[0]["content"]
+             system_prompt: str = messages[0]["content"]
              messages = messages[1:]
          else:
              system_prompt = ""
@@ -197,7 +254,10 @@ class AnthropicChatClient(BaseChatClient):

          messages = format_messages_alternate(messages)

-         tools_params = refactor_tool_use_params(tools) if tools else tools
+         tools_params: list[AnthropicToolParam] | NotGiven = refactor_tool_use_params(tools) if tools else NOT_GIVEN
+         tool_choice_param = NOT_GIVEN
+         if tool_choice:
+             tool_choice_param = refactor_tool_choice(tool_choice)

          if max_tokens is None:
              max_output_tokens = self.model_setting.max_output_tokens
@@ -208,24 +268,23 @@ class AnthropicChatClient(BaseChatClient):
          else:
              max_tokens = self.model_setting.context_length - token_counts

-         response = self.raw_client.messages.create(
-             model=self.model_setting.id,
-             messages=messages,
-             system=system_prompt,
-             stream=self.stream,
-             temperature=self.temperature,
-             max_tokens=max_tokens,
-             tools=tools_params,
-             tool_choice=tool_choice,
-             **kwargs,
-         )
-
          if self.stream:
+             stream_response = self.raw_client.messages.create(
+                 model=self.model_setting.id,
+                 messages=messages,
+                 system=system_prompt,
+                 stream=True,
+                 temperature=self.temperature,
+                 max_tokens=max_tokens,
+                 tools=tools_params,
+                 tool_choice=tool_choice_param,
+                 **kwargs,
+             )

              def generator():
-                 result = {"content": ""}
-                 for chunk in response:
-                     message = {"content": ""}
+                 result = {"content": "", "usage": {}, "tool_calls": []}
+                 for chunk in stream_response:
+                     message = {"content": "", "tool_calls": []}
                      if isinstance(chunk, RawMessageStartEvent):
                          result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
                          continue
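
Moving the `messages.create(...)` call into each branch with a literal `stream=True`/`stream=False` (instead of one call with `stream=self.stream`) lets the Anthropic SDK's own overloads type the result precisely, an event stream in one branch and a complete message in the other, rather than a union. The reworked generator also seeds `usage` and `tool_calls` up front so later chunks can update them without key errors. Consuming it would look roughly like this (hypothetical usage; names taken from this diff):

```python
# Hypothetical usage of the streaming path; assumes a configured client.
client = AnthropicChatClient()
for delta in client.create_completion(
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
):
    if delta.content:
        print(delta.content, end="", flush=True)
    if delta.usage:  # the final chunk carries the token totals
        print("\ntokens:", delta.usage.total_tokens)
```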
@@ -268,10 +327,22 @@ class AnthropicChatClient(BaseChatClient):
                          result["usage"]["total_tokens"] = (
                              result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
                          )
-                     yield ChatCompletionDeltaMessage(usage=result["usage"])
+                     yield ChatCompletionDeltaMessage(usage=Usage(**result["usage"]))

              return generator()
          else:
+             response = self.raw_client.messages.create(
+                 model=self.model_setting.id,
+                 messages=messages,
+                 system=system_prompt,
+                 stream=False,
+                 temperature=self.temperature,
+                 max_tokens=max_tokens,
+                 tools=tools_params,
+                 tool_choice=tool_choice_param,
+                 **kwargs,
+             )
+
              result = {
                  "content": "",
                  "usage": {
@@ -294,7 +365,7 @@ class AnthropicChatClient(BaseChatClient):


  class AsyncAnthropicChatClient(BaseAsyncChatClient):
-     DEFAULT_MODEL: str = defs.ANTHROPIC_DEFAULT_MODEL
+     DEFAULT_MODEL: str | None = defs.ANTHROPIC_DEFAULT_MODEL
      BACKEND_NAME: BackendType = BackendType.Anthropic

      def __init__(
@@ -319,7 +390,7 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
              **kwargs,
          )

-     @property
+     @cached_property
      def raw_client(self):
          if self.random_endpoint:
              self.random_endpoint = True
@@ -327,6 +398,8 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
              self.endpoint = settings.get_endpoint(self.endpoint_id)

          if self.endpoint.is_vertex:
+             if self.endpoint.credentials is None:
+                 raise ValueError("Anthropic Vertex endpoint requires credentials")
              self.creds = Credentials(
                  token=self.endpoint.credentials.get("token"),
                  refresh_token=self.endpoint.credentials.get("refresh_token"),
@@ -338,7 +411,7 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
                  expiry=_helpers.utcnow() - _helpers.REFRESH_THRESHOLD,
                  rapt_token=self.endpoint.credentials.get("rapt_token"),
                  trust_boundary=self.endpoint.credentials.get("trust_boundary"),
-                 universe_domain=self.endpoint.credentials.get("universe_domain"),
+                 universe_domain=self.endpoint.credentials.get("universe_domain", "googleapis.com"),
                  account=self.endpoint.credentials.get("account", ""),
              )

@@ -350,10 +423,11 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
          else:
              base_url = f"{self.endpoint.api_base}{self.endpoint.region}-aiplatform/v1"

+         region = NOT_GIVEN if self.endpoint.region is None else self.endpoint.region
          return AsyncAnthropicVertex(
-             region=self.endpoint.region,
+             region=region,
              base_url=base_url,
-             project_id=self.endpoint.credentials.get("quota_project_id"),
+             project_id=self.endpoint.credentials.get("quota_project_id", NOT_GIVEN),
              access_token=self.creds.token,
              http_client=self.http_client,
          )
@@ -364,15 +438,46 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
              http_client=self.http_client,
          )

+     @overload
+     async def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[False] = False,
+         temperature: float | None = None,
+         max_tokens: int | None = None,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> ChatCompletionMessage:
+         pass
+
+     @overload
+     async def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[True] = True,
+         temperature: float | None = None,
+         max_tokens: int | None = None,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> AsyncGenerator[ChatCompletionDeltaMessage, Any]:
+         pass
+
      async def create_completion(
          self,
-         messages: list = list,
+         messages: list,
          model: str | None = None,
          stream: bool | None = None,
          temperature: float | None = None,
          max_tokens: int | None = None,
-         tools: list | NotGiven = NOT_GIVEN,
-         tool_choice: str | NotGiven = NOT_GIVEN,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
          **kwargs,
      ):
          if model is not None:
@@ -404,7 +509,10 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):

          messages = format_messages_alternate(messages)

-         tools_params = refactor_tool_use_params(tools) if tools else tools
+         tools_params: list[AnthropicToolParam] | NotGiven = refactor_tool_use_params(tools) if tools else NOT_GIVEN
+         tool_choice_param = NOT_GIVEN
+         if tool_choice:
+             tool_choice_param = refactor_tool_choice(tool_choice)

          if max_tokens is None:
              max_output_tokens = self.model_setting.max_output_tokens
@@ -415,24 +523,23 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
          else:
              max_tokens = self.model_setting.context_length - token_counts

-         response = await self.raw_client.messages.create(
-             model=self.model_setting.id,
-             messages=messages,
-             system=system_prompt,
-             stream=self.stream,
-             temperature=self.temperature,
-             max_tokens=max_tokens,
-             tools=tools_params,
-             tool_choice=tool_choice,
-             **kwargs,
-         )
-
          if self.stream:
+             stream_response = await self.raw_client.messages.create(
+                 model=self.model_setting.id,
+                 messages=messages,
+                 system=system_prompt,
+                 stream=True,
+                 temperature=self.temperature,
+                 max_tokens=max_tokens,
+                 tools=tools_params,
+                 tool_choice=tool_choice_param,
+                 **kwargs,
+             )

              async def generator():
-                 result = {"content": ""}
-                 async for chunk in response:
-                     message = {"content": ""}
+                 result = {"content": "", "usage": {}, "tool_calls": []}
+                 async for chunk in stream_response:
+                     message = {"content": "", "tool_calls": []}
                      if isinstance(chunk, RawMessageStartEvent):
                          result["usage"] = {"prompt_tokens": chunk.message.usage.input_tokens}
                          continue
@@ -475,10 +582,22 @@ class AsyncAnthropicChatClient(BaseAsyncChatClient):
                          result["usage"]["total_tokens"] = (
                              result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"]
                          )
-                     yield ChatCompletionDeltaMessage(usage=result["usage"])
+                     yield ChatCompletionDeltaMessage(usage=Usage(**result["usage"]))

              return generator()
          else:
+             response = await self.raw_client.messages.create(
+                 model=self.model_setting.id,
+                 messages=messages,
+                 system=system_prompt,
+                 stream=False,
+                 temperature=self.temperature,
+                 max_tokens=max_tokens,
+                 tools=tools_params,
+                 tool_choice=tool_choice_param,
+                 **kwargs,
+             )
+
              result = {
                  "content": "",
                  "usage": {
@@ -1,22 +1,29 @@
  # @Author: Bi Ying
  # @Date: 2024-07-26 14:48:55
  from abc import ABC, abstractmethod
- from typing import Generator, AsyncGenerator, Any
+ from functools import cached_property
+ from typing import Generator, AsyncGenerator, Any, overload, Literal, Iterable

  import httpx
- from openai._types import NotGiven, NOT_GIVEN
  from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
  from anthropic import Anthropic, AnthropicVertex, AsyncAnthropic, AsyncAnthropicVertex

  from ..settings import settings
  from ..types import defaults as defs
  from ..types.enums import ContextLengthControlType, BackendType
- from ..types.llm_parameters import ChatCompletionMessage, ChatCompletionDeltaMessage
+ from ..types.llm_parameters import (
+     NotGiven,
+     NOT_GIVEN,
+     ToolParam,
+     ToolChoice,
+     ChatCompletionMessage,
+     ChatCompletionDeltaMessage,
+ )


  class BaseChatClient(ABC):
      DEFAULT_MODEL: str | None = None
-     BACKEND_NAME: BackendType | None = None
+     BACKEND_NAME: BackendType

      def __init__(
          self,
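
The remaining hunks touch the shared base-client module (the home of the `BaseChatClient`/`BaseAsyncChatClient` classes imported above). Two themes: `NotGiven`/`NOT_GIVEN` now come from the package's own `..types.llm_parameters` instead of `openai._types`, giving every backend a single sentinel type; and `BACKEND_NAME` tightens from `BackendType | None = None` to a bare `BackendType` annotation, which tells type checkers the attribute always exists and obliges subclasses to assign it. A stand-alone sketch of the annotation change:

```python
from abc import ABC
from enum import Enum

class BackendType(Enum):  # stand-in for the package's enum
    Anthropic = "anthropic"

class Base(ABC):
    BACKEND_NAME: BackendType  # bare annotation: no None default

class AnthropicClient(Base):
    BACKEND_NAME = BackendType.Anthropic  # subclasses must assign it

backend: BackendType = AnthropicClient.BACKEND_NAME  # no Optional narrowing needed
```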
@@ -44,9 +51,41 @@ class BaseChatClient(ABC):
          self.random_endpoint = False
          self.endpoint = settings.get_endpoint(self.endpoint_id)

-     @property
+     @cached_property
      @abstractmethod
-     def raw_client(self) -> OpenAI | AzureOpenAI | Anthropic | AnthropicVertex:
+     def raw_client(self) -> OpenAI | AzureOpenAI | Anthropic | AnthropicVertex | httpx.Client | None:
+         pass
+
+     @overload
+     @abstractmethod
+     def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[False] = False,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> ChatCompletionMessage:
+         pass
+
+     @overload
+     @abstractmethod
+     def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[True] = True,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> Generator[ChatCompletionDeltaMessage, Any, None]:
          pass

      @abstractmethod
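
Because the overloads now live on the abstract base as well, call sites against any backend that doesn't redeclare the signature get the same stream-dependent narrowing, e.g. (hypothetical snippet; the comments show what a checker such as pyright would infer):

```python
msgs = [{"role": "user", "content": "Hi"}]
client = AnthropicChatClient()

message = client.create_completion(messages=msgs, stream=False)
# inferred: ChatCompletionMessage -- fields available immediately

deltas = client.create_completion(messages=msgs, stream=True)
# inferred: Generator[ChatCompletionDeltaMessage, Any, None] -- must be iterated
```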
@@ -57,8 +96,8 @@ class BaseChatClient(ABC):
          stream: bool = False,
          temperature: float = 0.7,
          max_tokens: int | None = None,
-         tools: list | NotGiven = NOT_GIVEN,
-         tool_choice: str | NotGiven = NOT_GIVEN,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
          response_format: dict | None = None,
          **kwargs,
      ) -> ChatCompletionMessage | Generator[ChatCompletionDeltaMessage, Any, None]:
@@ -70,8 +109,9 @@ class BaseChatClient(ABC):
          model: str | None = None,
          temperature: float = 0.7,
          max_tokens: int | None = None,
-         tools: list | NotGiven = NOT_GIVEN,
-         tool_choice: str | NotGiven = NOT_GIVEN,
+         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
          **kwargs,
      ) -> Generator[ChatCompletionDeltaMessage, Any, None]:
          return self.create_completion(
@@ -82,13 +122,14 @@ class BaseChatClient(ABC):
              max_tokens=max_tokens,
              tools=tools,
              tool_choice=tool_choice,
+             response_format=response_format,
              **kwargs,
          )


  class BaseAsyncChatClient(ABC):
      DEFAULT_MODEL: str | None = None
-     BACKEND_NAME: BackendType | None = None
+     BACKEND_NAME: BackendType

      def __init__(
          self,
@@ -116,9 +157,43 @@ class BaseAsyncChatClient(ABC):
          self.random_endpoint = False
          self.endpoint = settings.get_endpoint(self.endpoint_id)

-     @property
+     @cached_property
      @abstractmethod
-     def raw_client(self) -> AsyncOpenAI | AsyncAzureOpenAI | AsyncAnthropic | AsyncAnthropicVertex:
+     def raw_client(
+         self,
+     ) -> AsyncOpenAI | AsyncAzureOpenAI | AsyncAnthropic | AsyncAnthropicVertex | httpx.AsyncClient | None:
+         pass
+
+     @overload
+     @abstractmethod
+     async def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[False] = False,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         tools: list | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> ChatCompletionMessage:
+         pass
+
+     @overload
+     @abstractmethod
+     async def create_completion(
+         self,
+         messages: list,
+         model: str | None = None,
+         stream: Literal[True] = True,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         tools: list | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
+         **kwargs,
+     ) -> AsyncGenerator[ChatCompletionDeltaMessage, None]:
          pass

      @abstractmethod
@@ -130,7 +205,7 @@ class BaseAsyncChatClient(ABC):
          temperature: float = 0.7,
          max_tokens: int | None = None,
          tools: list | NotGiven = NOT_GIVEN,
-         tool_choice: str | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
          response_format: dict | None = None,
          **kwargs,
      ) -> ChatCompletionMessage | AsyncGenerator[ChatCompletionDeltaMessage, None]:
@@ -143,7 +218,8 @@ class BaseAsyncChatClient(ABC):
          temperature: float = 0.7,
          max_tokens: int | None = None,
          tools: list | NotGiven = NOT_GIVEN,
-         tool_choice: str | NotGiven = NOT_GIVEN,
+         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
+         response_format: dict | None = None,
          **kwargs,
      ) -> AsyncGenerator[ChatCompletionDeltaMessage, None]:
          return await self.create_completion(
@@ -154,5 +230,6 @@ class BaseAsyncChatClient(ABC):
              max_tokens=max_tokens,
              tools=tools,
              tool_choice=tool_choice,
+             response_format=response_format,
              **kwargs,
          )
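
Finally, `create_stream` (sync and async alike) now accepts and forwards `response_format`, so it no longer silently drops the parameter relative to `create_completion`. Hypothetical usage (the value shape follows OpenAI's convention; whether a given backend honors it is up to that backend):

```python
async def run() -> None:
    client = AsyncAnthropicChatClient()
    stream = await client.create_stream(
        messages=[{"role": "user", "content": "Reply with a JSON object."}],
        response_format={"type": "json_object"},  # now reaches the backend
    )
    async for delta in stream:
        print(delta.content or "", end="")
```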