vectorvein 0.1.87__py3-none-any.whl → 0.1.89__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,527 +1,13 @@
-# @Author: Bi Ying
-# @Date: 2024-06-17 23:47:49
-import json
-from functools import cached_property
-from typing import Iterable, Literal, Generator, AsyncGenerator, overload, Any
+from ..types.enums import BackendType
+from ..types.defaults import GEMINI_DEFAULT_MODEL
+from .openai_compatible_client import OpenAICompatibleChatClient, AsyncOpenAICompatibleChatClient
 
-import httpx
 
-from .utils import cutoff_messages
-from ..types import defaults as defs
-from .base_client import BaseChatClient, BaseAsyncChatClient
-from ..types.enums import ContextLengthControlType, BackendType
-from ..types.llm_parameters import (
-    NotGiven,
-    NOT_GIVEN,
-    ToolParam,
-    ToolChoice,
-    ChatCompletionMessage,
-    ChatCompletionDeltaMessage,
-    ChatCompletionStreamOptionsParam,
-)
+class GeminiChatClient(OpenAICompatibleChatClient):
+    DEFAULT_MODEL = GEMINI_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Gemini
 
 
-class GeminiChatClient(BaseChatClient):
-    DEFAULT_MODEL: str = defs.GEMINI_DEFAULT_MODEL
-    BACKEND_NAME: BackendType = BackendType.Gemini
-
-    def __init__(
-        self,
-        model: str = defs.GEMINI_DEFAULT_MODEL,
-        stream: bool = True,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
-        random_endpoint: bool = True,
-        endpoint_id: str = "",
-        http_client: httpx.Client | None = None,
-        backend_name: str | None = None,
-    ):
-        super().__init__(
-            model,
-            stream,
-            temperature,
-            context_length_control,
-            random_endpoint,
-            endpoint_id,
-            http_client,
-            backend_name,
-        )
-        self.model_id = None
-        self.endpoint = None
-
-    @cached_property
-    def raw_client(self):
-        self.endpoint, self.model_id = self._set_endpoint()
-        if not self.http_client:
-            self.http_client = httpx.Client(timeout=300, proxy=self.endpoint.proxy)
-        return self.http_client
-
-    @overload
-    def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: Literal[False] = False,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ) -> ChatCompletionMessage:
-        pass
-
-    @overload
-    def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: Literal[True],
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ) -> Generator[ChatCompletionDeltaMessage, None, None]:
-        pass
-
-    @overload
-    def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: bool,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ) -> ChatCompletionMessage | Generator[ChatCompletionDeltaMessage, Any, None]:
-        pass
-
-    def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: Literal[False] | Literal[True] = False,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ):
-        if model is not None:
-            self.model = model
-        if stream is not None:
-            self.stream = stream
-        if temperature is not None:
-            self.temperature = temperature
-
-        self.model_setting = self.backend_settings.models[self.model]
-        if self.model_id is None:
-            self.model_id = self.model_setting.id
-
-        self.endpoint, self.model_id = self._set_endpoint()
-
-        if messages[0].get("role") == "system":
-            system_prompt = messages[0]["content"]
-            messages = messages[1:]
-        else:
-            system_prompt = ""
-
-        if not skip_cutoff and self.context_length_control == ContextLengthControlType.Latest:
-            messages = cutoff_messages(
-                messages,
-                max_count=self.model_setting.context_length,
-                backend=self.BACKEND_NAME,
-                model=self.model_setting.id,
-            )
-
-        tools_params = {}
-        if tools:
-            tools_params = {"tools": [{"function_declarations": [tool["function"] for tool in tools]}]}
-
-        response_format_params = {}
-        if response_format is not None:
-            if response_format.get("type") == "json_object":
-                response_format_params = {"response_mime_type": "application/json"}
-
-        top_p_params = {}
-        if top_p:
-            top_p_params = {"top_p": top_p}
-
-        temperature_params = {}
-        if temperature:
-            temperature_params = {"temperature": temperature}
-
-        request_body = {
-            "contents": messages,
-            "safetySettings": [
-                {
-                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                    "threshold": "BLOCK_ONLY_HIGH",
-                }
-            ],
-            "generationConfig": {
-                "maxOutputTokens": max_tokens,
-                **temperature_params,
-                **top_p_params,
-                **response_format_params,
-            },
-            **tools_params,
-            **kwargs,
-        }
-        if system_prompt:
-            request_body["systemInstruction"] = {"parts": [{"text": system_prompt}]}
-
-        headers = {"Content-Type": "application/json"}
-
-        params = {"key": self.endpoint.api_key}
-
-        if self.stream:
-            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:streamGenerateContent"
-            params["alt"] = "sse"
-
-            def generator():
-                result = {"content": "", "tool_calls": [], "usage": {}}
-                client = self.raw_client
-                with client.stream("POST", url, headers=headers, params=params, json=request_body) as response:
-                    for chunk in response.iter_lines():
-                        message = {"content": "", "tool_calls": []}
-                        if not chunk.startswith("data:"):
-                            continue
-                        data = json.loads(chunk[5:])
-                        chunk_content = data["candidates"][0]["content"]["parts"][0]
-                        if "text" in chunk_content:
-                            message["content"] = chunk_content["text"]
-                            result["content"] += message["content"]
-                        elif "functionCall" in chunk_content:
-                            message["tool_calls"] = [
-                                {
-                                    "index": 0,
-                                    "id": "call_0",
-                                    "function": {
-                                        "arguments": json.dumps(
-                                            chunk_content["functionCall"]["args"], ensure_ascii=False
-                                        ),
-                                        "name": chunk_content["functionCall"]["name"],
-                                    },
-                                    "type": "function",
-                                }
-                            ]
-
-                        result["usage"] = message["usage"] = {
-                            "prompt_tokens": data["usageMetadata"].get("promptTokenCount", 0),
-                            "completion_tokens": data["usageMetadata"].get("candidatesTokenCount", 0),
-                            "total_tokens": data["usageMetadata"].get("totalTokenCount", 0),
-                        }
-                        yield ChatCompletionDeltaMessage(**message)
-
-            return generator()
-        else:
-            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:generateContent"
-            client = self.raw_client
-            response = client.post(url, json=request_body, headers=headers, params=params, timeout=None).json()
-            if "error" in response:
-                raise Exception(response["error"])
-            result = {
-                "content": "",
-                "usage": {
-                    "prompt_tokens": response.get("usageMetadata", {}).get("promptTokenCount", 0),
-                    "completion_tokens": response.get("usageMetadata", {}).get("candidatesTokenCount", 0),
-                    "total_tokens": response.get("usageMetadata", {}).get("totalTokenCount", 0),
-                },
-            }
-            tool_calls = []
-            for part in response["candidates"][0]["content"]["parts"]:
-                if "text" in part:
-                    result["content"] += part["text"]
-                elif "functionCall" in part:
-                    tool_call = {
-                        "index": 0,
-                        "id": "call_0",
-                        "function": {
-                            "arguments": json.dumps(part["functionCall"]["args"], ensure_ascii=False),
-                            "name": part["functionCall"]["name"],
-                        },
-                        "type": "function",
-                    }
-                    tool_calls.append(tool_call)
-
-            if tool_calls:
-                result["tool_calls"] = tool_calls
-
-            return ChatCompletionMessage(**result)
-
-
-class AsyncGeminiChatClient(BaseAsyncChatClient):
-    DEFAULT_MODEL: str = defs.GEMINI_DEFAULT_MODEL
-    BACKEND_NAME: BackendType = BackendType.Gemini
-
-    def __init__(
-        self,
-        model: str = defs.GEMINI_DEFAULT_MODEL,
-        stream: bool = True,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        context_length_control: ContextLengthControlType = defs.CONTEXT_LENGTH_CONTROL,
-        random_endpoint: bool = True,
-        endpoint_id: str = "",
-        http_client: httpx.AsyncClient | None = None,
-        backend_name: str | None = None,
-    ):
-        super().__init__(
-            model,
-            stream,
-            temperature,
-            context_length_control,
-            random_endpoint,
-            endpoint_id,
-            http_client,
-            backend_name,
-        )
-        self.model_id = None
-        self.endpoint = None
-
-    @cached_property
-    def raw_client(self):
-        self.endpoint, self.model_id = self._set_endpoint()
-        if not self.http_client:
-            self.http_client = httpx.AsyncClient(timeout=300, proxy=self.endpoint.proxy)
-        return self.http_client
-
-    @overload
-    async def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: Literal[False] = False,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ) -> ChatCompletionMessage:
-        pass
-
-    @overload
-    async def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: Literal[True],
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ) -> AsyncGenerator[ChatCompletionDeltaMessage, Any]:
-        pass
-
-    @overload
-    async def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: bool,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ) -> ChatCompletionMessage | AsyncGenerator[ChatCompletionDeltaMessage, Any]:
-        pass
-
-    async def create_completion(
-        self,
-        *,
-        messages: list,
-        model: str | None = None,
-        stream: Literal[False] | Literal[True] = False,
-        temperature: float | None | NotGiven = NOT_GIVEN,
-        max_tokens: int | None = None,
-        tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN,
-        tool_choice: ToolChoice | NotGiven = NOT_GIVEN,
-        response_format: dict | None = None,
-        stream_options: ChatCompletionStreamOptionsParam | None = None,
-        top_p: float | NotGiven | None = NOT_GIVEN,
-        skip_cutoff: bool = False,
-        **kwargs,
-    ):
-        if model is not None:
-            self.model = model
-        if stream is not None:
-            self.stream = stream
-        if temperature is not None:
-            self.temperature = temperature
-
-        self.model_setting = self.backend_settings.models[self.model]
-        if self.model_id is None:
-            self.model_id = self.model_setting.id
-
-        self.endpoint, self.model_id = self._set_endpoint()
-
-        if messages[0].get("role") == "system":
-            system_prompt = messages[0]["content"]
-            messages = messages[1:]
-        else:
-            system_prompt = ""
-
-        if not skip_cutoff and self.context_length_control == ContextLengthControlType.Latest:
-            messages = cutoff_messages(
-                messages,
-                max_count=self.model_setting.context_length,
-                backend=self.BACKEND_NAME,
-                model=self.model_setting.id,
-            )
-
-        tools_params = {}
-        if tools:
-            tools_params = {"tools": [{"function_declarations": [tool["function"] for tool in tools]}]}
-
-        response_format_params = {}
-        if response_format is not None:
-            if response_format.get("type") == "json_object":
-                response_format_params = {"response_mime_type": "application/json"}
-
-        top_p_params = {}
-        if top_p:
-            top_p_params = {"top_p": top_p}
-
-        temperature_params = {}
-        if temperature:
-            temperature_params = {"temperature": temperature}
-
-        request_body = {
-            "contents": messages,
-            "safetySettings": [
-                {
-                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                    "threshold": "BLOCK_ONLY_HIGH",
-                }
-            ],
-            "generationConfig": {
-                "maxOutputTokens": max_tokens,
-                **temperature_params,
-                **top_p_params,
-                **response_format_params,
-            },
-            **tools_params,
-            **kwargs,
-        }
-        if system_prompt:
-            request_body["systemInstruction"] = {"parts": [{"text": system_prompt}]}
-
-        headers = {"Content-Type": "application/json"}
-
-        params = {"key": self.endpoint.api_key}
-
-        if self.stream:
-            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:streamGenerateContent"
-            params["alt"] = "sse"
-
-            async def generator():
-                result = {"content": "", "tool_calls": [], "usage": {}}
-                client = self.raw_client
-                async with client.stream("POST", url, headers=headers, params=params, json=request_body) as response:
-                    async for chunk in response.aiter_lines():
-                        message = {"content": "", "tool_calls": []}
-                        if not chunk.startswith("data:"):
-                            continue
-                        data = json.loads(chunk[5:])
-                        chunk_content = data["candidates"][0]["content"]["parts"][0]
-                        if "text" in chunk_content:
-                            message["content"] = chunk_content["text"]
-                            result["content"] += message["content"]
-                        elif "functionCall" in chunk_content:
-                            message["tool_calls"] = [
-                                {
-                                    "index": 0,
-                                    "id": "call_0",
-                                    "function": {
-                                        "arguments": json.dumps(
-                                            chunk_content["functionCall"]["args"], ensure_ascii=False
-                                        ),
-                                        "name": chunk_content["functionCall"]["name"],
-                                    },
-                                    "type": "function",
-                                }
-                            ]
-
-                        result["usage"] = message["usage"] = {
-                            "prompt_tokens": data["usageMetadata"].get("promptTokenCount", 0),
-                            "completion_tokens": data["usageMetadata"].get("candidatesTokenCount", 0),
-                            "total_tokens": data["usageMetadata"].get("totalTokenCount", 0),
-                        }
-                        yield ChatCompletionDeltaMessage(**message)
-
-            return generator()
-        else:
-            url = f"{self.endpoint.api_base}/models/{self.model_setting.id}:generateContent"
-            client = self.raw_client
-            async with client:
-                response = await client.post(url, json=request_body, headers=headers, params=params, timeout=None)
-                response = response.json()
-                if "error" in response:
-                    raise Exception(response["error"])
-                result = {
-                    "content": "",
-                    "usage": {
-                        "prompt_tokens": response.get("usageMetadata", {}).get("promptTokenCount", 0),
-                        "completion_tokens": response.get("usageMetadata", {}).get("candidatesTokenCount", 0),
-                        "total_tokens": response.get("usageMetadata", {}).get("totalTokenCount", 0),
-                    },
-                }
-                tool_calls = []
-                for part in response["candidates"][0]["content"]["parts"]:
-                    if "text" in part:
-                        result["content"] += part["text"]
-                    elif "functionCall" in part:
-                        tool_call = {
-                            "index": 0,
-                            "id": "call_0",
-                            "function": {
-                                "arguments": json.dumps(part["functionCall"]["args"], ensure_ascii=False),
-                                "name": part["functionCall"]["name"],
-                            },
-                            "type": "function",
-                        }
-                        tool_calls.append(tool_call)
-
-                if tool_calls:
-                    result["tool_calls"] = tool_calls
-
-                return ChatCompletionMessage(**result)
+class AsyncGeminiChatClient(AsyncOpenAICompatibleChatClient):
+    DEFAULT_MODEL = GEMINI_DEFAULT_MODEL
+    BACKEND_NAME = BackendType.Gemini
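Both Gemini clients are now thin subclasses of the OpenAI-compatible clients, so all request handling comes from the shared base. A minimal usage sketch, assuming the base keeps the constructor and `create_completion(messages=...)` interface shown for the removed implementation (the import path and any endpoint configuration are assumptions, not part of this diff):

```python
# Sketch only: the import path and configuration are assumed, not taken from this diff.
from vectorvein.chat_clients.gemini_client import GeminiChatClient

client = GeminiChatClient(stream=False)  # model defaults to DEFAULT_MODEL (GEMINI_DEFAULT_MODEL)
message = client.create_completion(
    messages=[{"role": "user", "content": "Summarize this change in one sentence."}],
)
# With stream=False a ChatCompletionMessage is returned; with stream=True the call
# yields ChatCompletionDeltaMessage chunks, mirroring the removed overloads above.
print(message.content)
```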
@@ -212,6 +212,8 @@ class OpenAICompatibleChatClient(BaseChatClient):
         else:
             _stream_options_params = {}
 
+        self._acquire_rate_limit(self.endpoint, self.model, messages)
+
         if self.stream:
             stream_response = raw_client.chat.completions.create(
                 model=self.model_id,
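The added `_acquire_rate_limit(self.endpoint, self.model, messages)` call gates every request before it is dispatched (the async client below awaits the same hook). Its implementation lives in the base client and is not part of this diff; purely as a hypothetical illustration, a hook of this shape could delegate to a simple per-endpoint limiter such as:

```python
# Hypothetical illustration only; not the limiter vectorvein actually ships.
import threading
import time


class SimpleRateLimiter:
    def __init__(self, requests_per_minute: int):
        self.interval = 60.0 / requests_per_minute
        self._lock = threading.Lock()
        self._next_allowed = 0.0

    def acquire(self) -> None:
        # Reserve the next request slot, then sleep until it arrives.
        with self._lock:
            now = time.monotonic()
            wait = max(0.0, self._next_allowed - now)
            self._next_allowed = max(now, self._next_allowed) + self.interval
        if wait:
            time.sleep(wait)
```

A real `_acquire_rate_limit(endpoint, model, messages)` would presumably look up the limiter configured for that endpoint and model (and possibly weigh the size of `messages`) before letting the request through.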
@@ -282,9 +284,9 @@ class OpenAICompatibleChatClient(BaseChatClient):
                             buffer = ""
                             break
 
-                    message["content"] = "".join(current_content).strip()
+                    message["content"] = "".join(current_content)
                     if current_reasoning:
-                        message["reasoning_content"] = "".join(current_reasoning).strip()
+                        message["reasoning_content"] = "".join(current_reasoning)
                     current_content.clear()
                     current_reasoning.clear()
 
@@ -307,8 +309,8 @@ class OpenAICompatibleChatClient(BaseChatClient):
                 else:
                     current_content.append(buffer)
                 final_message = {
-                    "content": "".join(current_content).strip(),
-                    "reasoning_content": "".join(current_reasoning).strip() if current_reasoning else None,
+                    "content": "".join(current_content),
+                    "reasoning_content": "".join(current_reasoning) if current_reasoning else None,
                 }
                 yield ChatCompletionDeltaMessage(**final_message, usage=usage)
 
@@ -338,8 +340,8 @@ class OpenAICompatibleChatClient(BaseChatClient):
             if not result["reasoning_content"] and result["content"]:
                 think_match = re.search(r"<think>(.*?)</think>", result["content"], re.DOTALL)
                 if think_match:
-                    result["reasoning_content"] = think_match.group(1).strip()
-                    result["content"] = result["content"].replace(think_match.group(0), "", 1).strip()
+                    result["reasoning_content"] = think_match.group(1)
+                    result["content"] = result["content"].replace(think_match.group(0), "", 1)
 
             if tools:
                 if self.model_setting.function_call_available and response.choices[0].message.tool_calls:
@@ -538,6 +540,8 @@ class AsyncOpenAICompatibleChatClient(BaseAsyncChatClient):
         else:
             max_tokens = self.model_setting.context_length - token_counts - 64
 
+        await self._acquire_rate_limit(self.endpoint, self.model, messages)
+
         if self.stream:
             stream_response = await raw_client.chat.completions.create(
                 model=self.model_id,
@@ -608,9 +612,9 @@ class AsyncOpenAICompatibleChatClient(BaseAsyncChatClient):
                             buffer = ""
                             break
 
-                    message["content"] = "".join(current_content).strip()
+                    message["content"] = "".join(current_content)
                     if current_reasoning:
-                        message["reasoning_content"] = "".join(current_reasoning).strip()
+                        message["reasoning_content"] = "".join(current_reasoning)
                     current_content.clear()
                     current_reasoning.clear()
 
@@ -633,8 +637,8 @@ class AsyncOpenAICompatibleChatClient(BaseAsyncChatClient):
                 else:
                     current_content.append(buffer)
                 final_message = {
-                    "content": "".join(current_content).strip(),
-                    "reasoning_content": "".join(current_reasoning).strip() if current_reasoning else None,
+                    "content": "".join(current_content),
+                    "reasoning_content": "".join(current_reasoning) if current_reasoning else None,
                 }
                 yield ChatCompletionDeltaMessage(**final_message, usage=usage)
 
@@ -663,8 +667,8 @@ class AsyncOpenAICompatibleChatClient(BaseAsyncChatClient):
             if not result["reasoning_content"] and result["content"]:
                 think_match = re.search(r"<think>(.*?)</think>", result["content"], re.DOTALL)
                 if think_match:
-                    result["reasoning_content"] = think_match.group(1).strip()
-                    result["content"] = result["content"].replace(think_match.group(0), "", 1).strip()
+                    result["reasoning_content"] = think_match.group(1)
+                    result["content"] = result["content"].replace(think_match.group(0), "", 1)
 
             if tools:
                 if self.model_setting.function_call_available and response.choices[0].message.tool_calls:
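The `.strip()` removals above change only whitespace handling: text extracted from a `<think>…</think>` block, and the content that remains after it is removed, now keep their original leading and trailing whitespace. A small standalone illustration of the non-streaming path's new behaviour:

```python
import re

content = "<think>\nchain of thought\n</think>\n\nFinal answer"

think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
if think_match:
    reasoning_content = think_match.group(1)                # "\nchain of thought\n" -- whitespace preserved
    content = content.replace(think_match.group(0), "", 1)  # "\n\nFinal answer" -- no trailing .strip() anymore

print(repr(reasoning_content))
print(repr(content))
```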