vectorvein-0.1.88-py3-none-any.whl → vectorvein-0.1.89-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
@@ -1,527 +1,13 @@
- # @Author: Bi Ying
- # @Date: 2024-06-17 23:47:49
- import json
- from functools import cached_property
- from typing import Iterable, Literal, Generator, AsyncGenerator, overload, Any
+ from ..types.enums import BackendType
+ from ..types.defaults import GEMINI_DEFAULT_MODEL
+ from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient

- import httpx

- from .utils import cutoff_messages
- from ..types import defaults as defs
- from .base_client import BaseChatClient, BaseAsyncChatClient
- from ..types.enums import ContextLengthControlType, BackendType
- from ..types.llm_parameters import (
-     NotGiven,
-     NOT_GIVEN,
-     ToolParam,
-     ToolChoice,
-     ChatCompletionMessage,
-     ChatCompletionDeltaMessage,
-     ChatCompletionStreamOptionsParam,
- )
+ class GeminiChatClient(OpenAICompatibleChatClient):
+     DEFAULT_MODEL = GEMINI_DEFAULT_MODEL
+     BACKEND_NAME = BackendType.Gemini


- class GeminiChatClient(BaseChatClient):
-     DEFAULT_MODEL: str = defs.GEMINI_DEFAULT_MODEL
-     BACKEND_NAME: BackendType = BackendType.Gemini
-
-     def __init__(
-         self,
-         model: str = defs.GEMINI_DEFAULT_MODEL,
-         stream: bool = True,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
-         random_endpoint: bool = True,
-         endpoint_id: str = "",
-         http_client: httpx.Client | None = None,
-         backend_name: str | None = None,
-     ):
-         super().__init__(
-             model,
-             stream,
-             temperature,
-             context_length_control,
-             random_endpoint,
-             endpoint_id,
-             http_client,
-             backend_name,
-         )
-         self.model_id = None
-         self.endpoint = None
-
-     @cached_property
-     def raw_client(self):
-         self.endpoint, self.model_id = self._set_endpoint()
-         if not self.http_client:
-             self.http_client = httpx.Client(timeout=300, proxy=self.endpoint.proxy)
-         return self.http_client
-
-     @overload
-     def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: Literal[False] = False,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ) -> ChatCompletionMessage:
-         pass
-
-     @overload
-     def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: Literal[True],
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ) -> Generator[ChatCompletionDeltaMessage, None, None]:
-         pass
-
-     @overload
-     def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: bool,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ) -> ChatCompletionMessage | Generator[ChatCompletionDeltaMessage, Any, None]:
-         pass
-
-     def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: Literal[False] | Literal[True] = False,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ):
-         if model is not None:
-             self.model = model
-         if stream is not None:
-             self.stream = stream
-         if temperature is not None:
-             self.temperature = temperature
-
-         self.model_setting = self.backend_settings.models[self.model]
-         if self.model_id is None:
-             self.model_id = self.model_setting.id
-
-         self.endpoint, self.model_id = self._set_endpoint()
-
-         if messages[0].get("role") == "system":
-             system_prompt = messages[0]["content"]
-             messages = messages[1:]
-         else:
-             system_prompt = ""
-
-         if not skip_cutoff and self.context_length_control == ContextLengthControlType.Latest:
-             messages = cutoff_messages(
-                 messages,
-                 max_count=self.model_setting.context_length,
-                 backend=self.BACKEND_NAME,
-                 model=self.model_setting.id,
-             )
-
-         tools_params = {}
-         if tools:
-             tools_params = {"tools": [{"function_declarations": [tool["function"] for tool in tools]}]}
-
-         response_format_params = {}
-         if response_format is not None:
-             if response_format.get("type") == "json_object":
-                 response_format_params = {"response_mime_type": "application/json"}
-
-         top_p_params = {}
-         if top_p:
-             top_p_params = {"top_p": top_p}
-
-         temperature_params = {}
-         if temperature:
-             temperature_params = {"temperature": temperature}
-
-         request_body = {
-             "contents": messages,
-             "safetySettings": [
-                 {
-                     "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                     "threshold": "BLOCK_ONLY_HIGH",
-                 }
-             ],
-             "generationConfig": {
-                 "maxOutputTokens": max_tokens,
-                 **temperature_params,
-                 **top_p_params,
-                 **response_format_params,
-             },
-             **tools_params,
-             **kwargs,
-         }
-         if system_prompt:
-             request_body["systemInstruction"] = {"parts": [{"text": system_prompt}]}
-
-         headers = {"Content-Type": "application/json"}
-
-         params = {"key": self.endpoint.api_key}
-
-         if self.stream:
-             url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:streamGenerateContent"
-             params["alt"] = "sse"
-
-             def generator():
-                 result = {"content": "", "tool_calls": [], "usage": {}}
-                 client = self.raw_client
-                 with client.stream("POST", url, headers=headers, params=params, json=request_body) as response:
-                     for chunk in response.iter_lines():
-                         message = {"content": "", "tool_calls": []}
-                         if not chunk.startswith("data:"):
-                             continue
-                         data = json.loads(chunk[5:])
-                         chunk_content = data["candidates"][0]["content"]["parts"][0]
-                         if "text" in chunk_content:
-                             message["content"] = chunk_content["text"]
-                             result["content"] += message["content"]
-                         elif "functionCall" in chunk_content:
-                             message["tool_calls"] = [
-                                 {
-                                     "index": 0,
-                                     "id": "call_0",
-                                     "function": {
-                                         "arguments": json.dumps(
-                                             chunk_content["functionCall"]["args"], ensure_ascii=False
-                                         ),
-                                         "name": chunk_content["functionCall"]["name"],
-                                     },
-                                     "type": "function",
-                                 }
-                             ]
-
-                         result["usage"] = message["usage"] = {
-                             "prompt_tokens": data["usageMetadata"].get("promptTokenCount", 0),
-                             "completion_tokens": data["usageMetadata"].get("candidatesTokenCount", 0),
-                             "total_tokens": data["usageMetadata"].get("totalTokenCount", 0),
-                         }
-                         yield ChatCompletionDeltaMessage(**message)
-
-             return generator()
-         else:
-             url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:generateContent"
-             client = self.raw_client
-             response = client.post(url, json=request_body, headers=headers, params=params, timeout=None).json()
-             if "error" in response:
-                 raise Exception(response["error"])
-             result = {
-                 "content": "",
-                 "usage": {
-                     "prompt_tokens": response.get("usageMetadata", {}).get("promptTokenCount", 0),
-                     "completion_tokens": response.get("usageMetadata", {}).get("candidatesTokenCount", 0),
-                     "total_tokens": response.get("usageMetadata", {}).get("totalTokenCount", 0),
-                 },
-             }
-             tool_calls = []
-             for part in response["candidates"][0]["content"]["parts"]:
-                 if "text" in part:
-                     result["content"] += part["text"]
-                 elif "functionCall" in part:
-                     tool_call = {
-                         "index": 0,
-                         "id": "call_0",
-                         "function": {
-                             "arguments": json.dumps(part["functionCall"]["args"], ensure_ascii=False),
-                             "name": part["functionCall"]["name"],
-                         },
-                         "type": "function",
-                     }
-                     tool_calls.append(tool_call)
-
-             if tool_calls:
-                 result["tool_calls"] = tool_calls
-
-             return ChatCompletionMessage(**result)
-
-
- class AsyncGeminiChatClient(BaseAsyncChatClient):
-     DEFAULT_MODEL: str = defs.GEMINI_DEFAULT_MODEL
-     BACKEND_NAME: BackendType = BackendType.Gemini
-
-     def __init__(
-         self,
-         model: str = defs.GEMINI_DEFAULT_MODEL,
-         stream: bool = True,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
-         random_endpoint: bool = True,
-         endpoint_id: str = "",
-         http_client: httpx.AsyncClient | None = None,
-         backend_name: str | None = None,
-     ):
-         super().__init__(
-             model,
-             stream,
-             temperature,
-             context_length_control,
-             random_endpoint,
-             endpoint_id,
-             http_client,
-             backend_name,
-         )
-         self.model_id = None
-         self.endpoint = None
-
-     @cached_property
-     def raw_client(self):
-         self.endpoint, self.model_id = self._set_endpoint()
-         if not self.http_client:
-             self.http_client = httpx.AsyncClient(timeout=300, proxy=self.endpoint.proxy)
-         return self.http_client
-
-     @overload
-     async def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: Literal[False] = False,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ) -> ChatCompletionMessage:
-         pass
-
-     @overload
-     async def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: Literal[True],
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ) -> AsyncGenerator[ChatCompletionDeltaMessage, Any]:
-         pass
-
-     @overload
-     async def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: bool,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ) -> ChatCompletionMessage | AsyncGenerator[ChatCompletionDeltaMessage, Any]:
-         pass
-
-     async def create_completion(
-         self,
-         *,
-         messages: list,
-         model: str | None = None,
-         stream: Literal[False] | Literal[True] = False,
-         temperature: float | None | NotGiven = NOT_GIVEN,
-         max_tokens: int | None = None,
-         tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-         tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-         response_format: dict | None = None,
-         stream_options: ChatCompletionStreamOptionsParam | None = None,
-         top_p: float | NotGiven | None = NOT_GIVEN,
-         skip_cutoff: bool = False,
-         **kwargs,
-     ):
-         if model is not None:
-             self.model = model
-         if stream is not None:
-             self.stream = stream
-         if temperature is not None:
-             self.temperature = temperature
-
-         self.model_setting = self.backend_settings.models[self.model]
-         if self.model_id is None:
-             self.model_id = self.model_setting.id
-
-         self.endpoint, self.model_id = self._set_endpoint()
-
-         if messages[0].get("role") == "system":
-             system_prompt = messages[0]["content"]
-             messages = messages[1:]
-         else:
-             system_prompt = ""
-
-         if not skip_cutoff and self.context_length_control == ContextLengthControlType.Latest:
-             messages = cutoff_messages(
-                 messages,
-                 max_count=self.model_setting.context_length,
-                 backend=self.BACKEND_NAME,
-                 model=self.model_setting.id,
-             )
-
-         tools_params = {}
-         if tools:
-             tools_params = {"tools": [{"function_declarations": [tool["function"] for tool in tools]}]}
-
-         response_format_params = {}
-         if response_format is not None:
-             if response_format.get("type") == "json_object":
-                 response_format_params = {"response_mime_type": "application/json"}
-
-         top_p_params = {}
-         if top_p:
-             top_p_params = {"top_p": top_p}
-
-         temperature_params = {}
-         if temperature:
-             temperature_params = {"temperature": temperature}
-
-         request_body = {
-             "contents": messages,
-             "safetySettings": [
-                 {
-                     "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                     "threshold": "BLOCK_ONLY_HIGH",
-                 }
-             ],
-             "generationConfig": {
-                 "maxOutputTokens": max_tokens,
-                 **temperature_params,
-                 **top_p_params,
-                 **response_format_params,
-             },
-             **tools_params,
-             **kwargs,
-         }
-         if system_prompt:
-             request_body["systemInstruction"] = {"parts": [{"text": system_prompt}]}
-
-         headers = {"Content-Type": "application/json"}
-
-         params = {"key": self.endpoint.api_key}
-
-         if self.stream:
-             url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:streamGenerateContent"
-             params["alt"] = "sse"
-
-             async def generator():
-                 result = {"content": "", "tool_calls": [], "usage": {}}
-                 client = self.raw_client
-                 async with client.stream("POST", url, headers=headers, params=params, json=request_body) as response:
-                     async for chunk in response.aiter_lines():
-                         message = {"content": "", "tool_calls": []}
-                         if not chunk.startswith("data:"):
-                             continue
-                         data = json.loads(chunk[5:])
-                         chunk_content = data["candidates"][0]["content"]["parts"][0]
-                         if "text" in chunk_content:
-                             message["content"] = chunk_content["text"]
-                             result["content"] += message["content"]
-                         elif "functionCall" in chunk_content:
-                             message["tool_calls"] = [
-                                 {
-                                     "index": 0,
-                                     "id": "call_0",
-                                     "function": {
-                                         "arguments": json.dumps(
-                                             chunk_content["functionCall"]["args"], ensure_ascii=False
-                                         ),
-                                         "name": chunk_content["functionCall"]["name"],
-                                     },
-                                     "type": "function",
-                                 }
-                             ]
-
-                         result["usage"] = message["usage"] = {
-                             "prompt_tokens": data["usageMetadata"].get("promptTokenCount", 0),
-                             "completion_tokens": data["usageMetadata"].get("candidatesTokenCount", 0),
-                             "total_tokens": data["usageMetadata"].get("totalTokenCount", 0),
-                         }
-                         yield ChatCompletionDeltaMessage(**message)
-
-             return generator()
-         else:
-             url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:generateContent"
-             client = self.raw_client
-             async with client:
-                 response = await client.post(url, json=request_body, headers=headers, params=params, timeout=None)
-                 response = response.json()
-                 if "error" in response:
-                     raise Exception(response["error"])
-                 result = {
-                     "content": "",
-                     "usage": {
-                         "prompt_tokens": response.get("usageMetadata", {}).get("promptTokenCount", 0),
-                         "completion_tokens": response.get("usageMetadata", {}).get("candidatesTokenCount", 0),
-                         "total_tokens": response.get("usageMetadata", {}).get("totalTokenCount", 0),
-                     },
-                 }
-                 tool_calls = []
-                 for part in response["candidates"][0]["content"]["parts"]:
-                     if "text" in part:
-                         result["content"] += part["text"]
-                     elif "functionCall" in part:
-                         tool_call = {
-                             "index": 0,
-                             "id": "call_0",
-                             "function": {
-                                 "arguments": json.dumps(part["functionCall"]["args"], ensure_ascii=False),
-                                 "name": part["functionCall"]["name"],
-                             },
-                             "type": "function",
-                         }
-                         tool_calls.append(tool_call)
-
-                 if tool_calls:
-                     result["tool_calls"] = tool_calls
-
-                 return ChatCompletionMessage(**result)
+ class AsyncGeminiChatClient(AsyncOpenAICompatibleChatClient):
+     DEFAULT_MODEL = GEMINI_DEFAULT_MODEL
+     BACKEND_NAME = BackendType.Gemini
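
Net effect of this hunk: the hand-written Gemini HTTP client, which called `generateContent`/`streamGenerateContent` directly via `httpx`, is replaced by thin subclasses of the shared OpenAI-compatible client that only pin the default model and backend name. A minimal usage sketch follows; it assumes the OpenAI-compatible base keeps the `create_completion(messages=...)` interface the removed client exposed, and the import path and model name are placeholders, since neither the base class nor the file path appears in this diff.

```python
# Usage sketch only: import path and model name below are assumptions, and the
# constructor/create_completion signatures are inherited from the
# OpenAI-compatible base client, which this diff does not show.
from vectorvein.chat_clients.gemini_client import GeminiChatClient  # path assumed

client = GeminiChatClient(model="gemini-1.5-pro", stream=False)  # model name assumed
reply = client.create_completion(
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(reply.content)  # ChatCompletionMessage exposes .content, as in the removed code
```
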
@@ -212,6 +212,8 @@ class OpenAICompatibleChatClient(BaseChatClient):
          else:
              _stream_options_params = {}

+         self._acquire_rate_limit(self.endpoint, self.model, messages)
+
          if self.stream:
              stream_response = raw_client.chat.completions.create(
                  model=self.model_id,
@@ -538,6 +540,8 @@ class AsyncOpenAICompatibleChatClient(BaseAsyncChatClient):
          else:
              max_tokens = self.model_setting.context_length - token_counts - 64

+         await self._acquire_rate_limit(self.endpoint, self.model, messages)
+
          if self.stream:
              stream_response = await raw_client.chat.completions.create(
                  model=self.model_id,
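
The two small hunks make the same change in the sync and async OpenAI-compatible clients: a `_acquire_rate_limit(self.endpoint, self.model, messages)` call is inserted just before the request is dispatched, so every completion (streaming or not) passes through a rate-limit gate. The limiter itself is not shown in this diff; purely as an illustration of the general pattern such a hook typically wraps, and not vectorvein's actual code, a fixed-window limiter might look like the sketch below.

```python
# Hypothetical stand-in for illustration only; not part of the vectorvein package.
import threading
import time


class SimpleRateLimiter:
    """Naive fixed-window limiter: at most `max_requests` per `window` seconds."""

    def __init__(self, max_requests: int, window: float = 60.0):
        self.max_requests = max_requests
        self.window = window
        self._lock = threading.Lock()
        self._timestamps: list[float] = []

    def acquire(self) -> None:
        # Block until a request slot is available in the current window.
        while True:
            with self._lock:
                now = time.monotonic()
                # Drop timestamps that have aged out of the window.
                self._timestamps = [t for t in self._timestamps if now - t < self.window]
                if len(self._timestamps) < self.max_requests:
                    self._timestamps.append(now)
                    return
                wait = self.window - (now - self._timestamps[0])
            time.sleep(max(wait, 0.01))
```
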