bisheng_langchain-0.0.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (41)
  1. bisheng_langchain/__init__.py +0 -0
  2. bisheng_langchain/chains/__init__.py +5 -0
  3. bisheng_langchain/chains/combine_documents/__init__.py +0 -0
  4. bisheng_langchain/chains/combine_documents/stuff.py +56 -0
  5. bisheng_langchain/chains/question_answering/__init__.py +240 -0
  6. bisheng_langchain/chains/retrieval_qa/__init__.py +0 -0
  7. bisheng_langchain/chains/retrieval_qa/base.py +89 -0
  8. bisheng_langchain/chat_models/__init__.py +11 -0
  9. bisheng_langchain/chat_models/host_llm.py +409 -0
  10. bisheng_langchain/chat_models/interface/__init__.py +10 -0
  11. bisheng_langchain/chat_models/interface/minimax.py +123 -0
  12. bisheng_langchain/chat_models/interface/openai.py +68 -0
  13. bisheng_langchain/chat_models/interface/types.py +61 -0
  14. bisheng_langchain/chat_models/interface/utils.py +5 -0
  15. bisheng_langchain/chat_models/interface/wenxin.py +114 -0
  16. bisheng_langchain/chat_models/interface/xunfei.py +233 -0
  17. bisheng_langchain/chat_models/interface/zhipuai.py +81 -0
  18. bisheng_langchain/chat_models/minimax.py +354 -0
  19. bisheng_langchain/chat_models/proxy_llm.py +354 -0
  20. bisheng_langchain/chat_models/wenxin.py +349 -0
  21. bisheng_langchain/chat_models/xunfeiai.py +355 -0
  22. bisheng_langchain/chat_models/zhipuai.py +379 -0
  23. bisheng_langchain/document_loaders/__init__.py +3 -0
  24. bisheng_langchain/document_loaders/elem_html.py +0 -0
  25. bisheng_langchain/document_loaders/elem_image.py +0 -0
  26. bisheng_langchain/document_loaders/elem_pdf.py +655 -0
  27. bisheng_langchain/document_loaders/parsers/__init__.py +5 -0
  28. bisheng_langchain/document_loaders/parsers/image.py +28 -0
  29. bisheng_langchain/document_loaders/parsers/test_image.py +286 -0
  30. bisheng_langchain/embeddings/__init__.py +7 -0
  31. bisheng_langchain/embeddings/host_embedding.py +133 -0
  32. bisheng_langchain/embeddings/interface/__init__.py +3 -0
  33. bisheng_langchain/embeddings/interface/types.py +23 -0
  34. bisheng_langchain/embeddings/interface/wenxin.py +86 -0
  35. bisheng_langchain/embeddings/wenxin.py +139 -0
  36. bisheng_langchain/vectorstores/__init__.py +3 -0
  37. bisheng_langchain/vectorstores/elastic_keywords_search.py +284 -0
  38. bisheng_langchain-0.0.1.dist-info/METADATA +64 -0
  39. bisheng_langchain-0.0.1.dist-info/RECORD +41 -0
  40. bisheng_langchain-0.0.1.dist-info/WHEEL +5 -0
  41. bisheng_langchain-0.0.1.dist-info/top_level.txt +1 -0
bisheng_langchain/chat_models/host_llm.py
@@ -0,0 +1,409 @@
+ """proxy llm chat wrapper."""
+ from __future__ import annotations
+
+ import logging
+ import sys
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+
+ import requests
+ from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
+ from langchain.chat_models.base import BaseChatModel
+ from langchain.schema import ChatGeneration, ChatResult
+ from langchain.schema.messages import (AIMessage, BaseMessage, ChatMessage, FunctionMessage,
+                                         HumanMessage, SystemMessage)
+ from langchain.utils import get_from_dict_or_env
+ from pydantic import Field, root_validator
+ # from requests.exceptions import HTTPError
+ from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
+                       wait_exponential)
+
+ # from .interface import MinimaxChatCompletion
+ # from .interface.types import ChatInput
+
+ if TYPE_CHECKING:
+     import tiktoken
+
+ logger = logging.getLogger(__name__)
+
+
+ def _import_tiktoken() -> Any:
+     try:
+         import tiktoken
+     except ImportError:
+         raise ValueError('Could not import tiktoken python package. '
+                          'This is needed in order to calculate get_token_ids. '
+                          'Please install it with `pip install tiktoken`.')
+     return tiktoken
+
+
+ def _create_retry_decorator(llm: BaseHostChatLLM) -> Callable[[Any], Any]:
+
+     min_seconds = 1
+     max_seconds = 20
+     # Wait 2^x * 1 second between each retry starting with
+     # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+     return retry(
+         reraise=True,
+         stop=stop_after_attempt(llm.max_retries),
+         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+         retry=(retry_if_exception_type(Exception)),
+         before_sleep=before_sleep_log(logger, logging.WARNING),
+     )
+
+
+ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+     role = _dict['role']
+     if role == 'user':
+         return HumanMessage(content=_dict['content'])
+     elif role == 'assistant':
+         content = _dict['content'] or ''  # OpenAI returns None for tool invocations
+         if _dict.get('function_call'):
+             additional_kwargs = {'function_call': dict(_dict['function_call'])}
+         else:
+             additional_kwargs = {}
+         return AIMessage(content=content, additional_kwargs=additional_kwargs)
+     elif role == 'system':
+         return SystemMessage(content=_dict['content'])
+     elif role == 'function':
+         return FunctionMessage(content=_dict['content'], name=_dict['name'])
+     else:
+         return ChatMessage(content=_dict['content'], role=role)
+
+
+ def _convert_message_to_dict(message: BaseMessage) -> dict:
+     if isinstance(message, ChatMessage):
+         message_dict = {'role': message.role, 'content': message.content}
+     elif isinstance(message, HumanMessage):
+         message_dict = {'role': 'user', 'content': message.content}
+     elif isinstance(message, AIMessage):
+         message_dict = {'role': 'assistant', 'content': message.content}
+         if 'function_call' in message.additional_kwargs:
+             message_dict['function_call'] = message.additional_kwargs['function_call']
+     elif isinstance(message, SystemMessage):
+         message_dict = {'role': 'system', 'content': message.content}
+     elif isinstance(message, FunctionMessage):
+         message_dict = {
+             'role': 'function',
+             'content': message.content,
+             'name': message.name,
+         }
+     else:
+         raise ValueError(f'Got unknown type {message}')
+     if 'name' in message.additional_kwargs:
+         message_dict['name'] = message.additional_kwargs['name']
+     return message_dict
+
+
+ class BaseHostChatLLM(BaseChatModel):
+     """Wrapper around hosted chat large language models."""
+
+     client: Optional[Any]  #: :meta private:
+     model_name: str = Field('', alias='model')
+     """Model name to use."""
+
+     temperature: float = 0.9
+     """What sampling temperature to use."""
+     top_p: float = 0.95
+     do_sample: bool = False
+     max_tokens: int = 4096
+     """Maximum number of tokens to generate."""
+     model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
+     """Holds any model parameters valid for `create` call not explicitly specified."""
+     host_base_url: Optional[str] = None
+
+     headers: Optional[Dict[str, str]] = Field(default_factory=dict)
+
+     request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+     """Timeout for requests to the hosted completion API. Default is 600 seconds."""
+     max_retries: Optional[int] = 6
+     """Maximum number of retries to make when generating."""
+     streaming: Optional[bool] = False
+     """Whether to stream the results or not."""
+     n: Optional[int] = 1
+     """Number of chat completions to generate for each prompt."""
+     tiktoken_model_name: Optional[str] = None
+     """The model name to pass to tiktoken when using this class.
+     Tiktoken is used to count the number of tokens in documents to constrain
+     them to be under a certain limit. By default, when set to None, this will
+     be the same as the embedding model name. However, there are some cases
+     where you may want to use this Embedding class with a model name not
+     supported by tiktoken. This can include when using Azure embeddings or
+     when using one of the many model providers that expose an OpenAI-like
+     API but with different models. In those cases, in order to avoid erroring
+     when tiktoken is called, you can specify a model name to use here."""
+
+     verbose: Optional[bool] = False
+
+     class Config:
+         """Configuration for this pydantic object."""
+
+         allow_population_by_field_name = True
+
+     @root_validator()
+     def validate_environment(cls, values: Dict) -> Dict:
+         """Validate that the host base url and required packages exist in the environment."""
+         values['host_base_url'] = get_from_dict_or_env(values, 'host_base_url',
+                                                        'HostBaseUrl')
+         try:
+             values['client'] = requests.post
+         except AttributeError:
+             raise ValueError(
+                 'Try upgrading it with `pip install --upgrade requests`.')
+         return values
+
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters for calling the hosted model API."""
+         return {
+             'model': self.model_name,
+             'temperature': self.temperature,
+             'top_p': self.top_p,
+             'max_tokens': self.max_tokens,
+             'do_sample': self.do_sample,
+             **self.model_kwargs,
+         }
+
+     def completion_with_retry(self, **kwargs: Any) -> Any:
+         retry_decorator = _create_retry_decorator(self)
+
+         @retry_decorator
+         def _completion_with_retry(**kwargs: Any) -> Any:
+             messages = kwargs.get('messages')
+             temperature = kwargs.get('temperature')
+             top_p = kwargs.get('top_p')
+             max_tokens = kwargs.get('max_tokens')
+             do_sample = kwargs.get('do_sample')
+             params = {
+                 'messages': messages,
+                 'model': self.model_name,
+                 'top_p': top_p,
+                 'temperature': temperature,
+                 'max_tokens': max_tokens,
+                 'do_sample': do_sample
+             }
+
+             if self.verbose:
+                 print('payload', params)
+
+             url = f'{self.host_base_url}/{self.model_name}/infer'
+             resp = self.client(url=url, json=params).json()
+             if resp['status_code'] != 200:
+                 raise ValueError(
+                     f"API returned an error: {resp['status_message']}")
+             resp['usage'] = {}
+             return resp
+
+         return _completion_with_retry(**kwargs)
+
+     def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+         overall_token_usage: dict = {}
+         for output in llm_outputs:
+             if output is None:
+                 # Happens in streaming
+                 continue
+             token_usage = output['token_usage']
+             if token_usage is None:
+                 continue
+
+             for k, v in token_usage.items():
+                 if k in overall_token_usage:
+                     overall_token_usage[k] += v
+                 else:
+                     overall_token_usage[k] = v
+         return {
+             'token_usage': overall_token_usage,
+             'model_name': self.model_name
+         }
+
+     def _generate(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> ChatResult:
+         message_dicts, params = self._create_message_dicts(messages, stop)
+         params = {**params, **kwargs}
+         response = self.completion_with_retry(messages=message_dicts, **params)
+         return self._create_chat_result(response)
+
+     async def _agenerate(
+         self,
+         messages: List[BaseMessage],
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> ChatResult:
+         return self._generate(messages, stop, run_manager, **kwargs)
+
+     def _create_message_dicts(
+         self, messages: List[BaseMessage], stop: Optional[List[str]]
+     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+         params = dict(self._client_params)
+         if stop is not None:
+             if 'stop' in params:
+                 raise ValueError(
+                     '`stop` found in both the input and default params.')
+             params['stop'] = stop
+
+         message_dicts = [_convert_message_to_dict(m) for m in messages]
+
+         return message_dicts, params
+
+     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+         generations = []
+         for res in response['choices']:
+             message = _convert_dict_to_message(res['message'])
+             gen = ChatGeneration(message=message)
+             generations.append(gen)
+
+         llm_output = {
+             'token_usage': response['usage'],
+             'model_name': self.model_name
+         }
+         return ChatResult(generations=generations, llm_output=llm_output)
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {**{'model_name': self.model_name}, **self._default_params}
+
+     @property
+     def _client_params(self) -> Mapping[str, Any]:
+         """Get the parameters used for the client."""
+         minimaxai_creds: Dict[str, Any] = {
+             'model': self.model_name,
+         }
+         return {**minimaxai_creds, **self._default_params}
+
+     def _get_invocation_params(self,
+                                stop: Optional[List[str]] = None,
+                                **kwargs: Any) -> Dict[str, Any]:
+         """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+         return {
+             **super()._get_invocation_params(stop=stop, **kwargs),
+             **self._default_params,
+             'model': self.model_name,
+             'function': kwargs.get('functions'),
+         }
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of chat model."""
+         return 'host_chat_llm'
+
+     def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
+         tiktoken_ = _import_tiktoken()
+         if self.tiktoken_model_name is not None:
+             model = self.tiktoken_model_name
+         else:
+             model = self.model_name
+         # model chatglm-std, chatglm-lite
+         # Returns the number of tokens used by a list of messages.
+         try:
+             encoding = tiktoken_.encoding_for_model(model)
+         except KeyError:
+             logger.warning(
+                 'Warning: model not found. Using cl100k_base encoding.')
+             model = 'cl100k_base'
+             encoding = tiktoken_.get_encoding(model)
+         return model, encoding
+
+     def get_token_ids(self, text: str) -> List[int]:
+         """Get the tokens present in the text with tiktoken package."""
+         # tiktoken NOT supported for Python 3.7 or below
+         if sys.version_info[1] <= 7:
+             return super().get_token_ids(text)
+         _, encoding_model = self._get_encoding_model()
+         return encoding_model.encode(text)
+
+     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+         """Calculate num tokens for chatglm with tiktoken package.
+
+         todo: read chatglm document
+         Official documentation: https://github.com/openai/openai-cookbook/blob/
+         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+         if sys.version_info[1] <= 7:
+             return super().get_num_tokens_from_messages(messages)
+         model, encoding = self._get_encoding_model()
+         if model.startswith('chatglm'):
+             # every message follows <im_start>{role/name}\n{content}<im_end>\n
+             tokens_per_message = 4
+             # if there's a name, the role is omitted
+             tokens_per_name = -1
+         else:
+             raise NotImplementedError(
+                 f'get_num_tokens_from_messages() is not presently implemented '
+                 f'for model {model}.'
+                 'See https://github.com/openai/openai-python/blob/main/chatml.md for '
+                 'information on how messages are converted to tokens.')
+         num_tokens = 0
+         messages_dict = [_convert_message_to_dict(m) for m in messages]
+         for message in messages_dict:
+             num_tokens += tokens_per_message
+             for key, value in message.items():
+                 num_tokens += len(encoding.encode(value))
+                 if key == 'name':
+                     num_tokens += tokens_per_name
+         # every reply is primed with <im_start>assistant
+         num_tokens += 3
+         return num_tokens
+
+
+ class ChatGLM2Host(BaseHostChatLLM):
+     # chatglm2-12b, chatglm2-6b
+     model_name: str = Field('chatglm2-6b', alias='model')
+
+     temperature: float = 0.95
+     top_p: float = 0.7
+     max_tokens: int = 4096
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of chat model."""
+         return 'chatglm2'
+
+
+ class BaichuanChat(BaseHostChatLLM):
+     # Baichuan-7B-Chat, Baichuan-13B-Chat
+     model_name: str = Field('Baichuan-13B-Chat', alias='model')
+
+     temperature: float = 0.3
+     top_p: float = 0.85
+     max_tokens: int = 8192
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of chat model."""
+         return 'baichang_chat'
+
+
+ class QwenChat(BaseHostChatLLM):
+     # Qwen-7B-Chat
+     model_name: str = Field('Qwen-7B-Chat', alias='model')
+
+     temperature: float = 0
+     top_p: float = 0.5
+     max_tokens: int = 8192
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of chat model."""
+         return 'qwen_chat'
+
+
+ class Llama2Chat(BaseHostChatLLM):
+     # Llama-2-7b-chat-hf, Llama-2-13b-chat-hf, Llama-2-70b-chat-hf
+     model_name: str = Field('Llama-2-7b-chat-hf', alias='model')
+
+     temperature: float = 0.9
+     top_p: float = 0.6
+     max_tokens: int = 8192
+
+     @property
+     def _llm_type(self) -> str:
+         """Return type of chat model."""
+         return 'llama2_chat'
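
For orientation, the 409-line hunk above corresponds to bisheng_langchain/chat_models/host_llm.py in the file list. Below is a minimal usage sketch of these hosted wrappers; the server URL is a placeholder, the callable-messages style assumes the LangChain 0.0.x BaseChatModel API imported by this module, and requests are POSTed to {host_base_url}/{model_name}/infer as implemented in completion_with_retry().

from langchain.schema.messages import HumanMessage, SystemMessage

from bisheng_langchain.chat_models.host_llm import ChatGLM2Host

# host_base_url is a placeholder for a real inference endpoint; the wrapper posts the
# payload to f"{host_base_url}/{model_name}/infer".
llm = ChatGLM2Host(
    model='chatglm2-6b',
    host_base_url='http://127.0.0.1:9000/v2/models',
    max_tokens=1024,
    verbose=True,
)

messages = [
    SystemMessage(content='You are a helpful assistant.'),
    HumanMessage(content='Briefly introduce LangChain.'),
]
result = llm(messages)  # returns an AIMessage built by _create_chat_result()
print(result.content)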
bisheng_langchain/chat_models/interface/__init__.py
@@ -0,0 +1,10 @@
+ from .minimax import ChatCompletion as MinimaxChatCompletion
+ from .openai import ChatCompletion as OpenaiChatCompletion
+ from .wenxin import ChatCompletion as WenxinChatCompletion
+ from .xunfei import ChatCompletion as XunfeiChatCompletion
+ from .zhipuai import ChatCompletion as ZhipuaiChatCompletion
+
+ __all__ = [
+     'MinimaxChatCompletion', 'OpenaiChatCompletion', 'WenxinChatCompletion',
+     'XunfeiChatCompletion', 'ZhipuaiChatCompletion'
+ ]
bisheng_langchain/chat_models/interface/minimax.py
@@ -0,0 +1,123 @@
+ import json
+
+ import requests
+
+ from .types import ChatInput, ChatOutput, Choice, Message, Usage
+ from .utils import get_ts
+
+
+ class ChatCompletion(object):
+
+     def __init__(self, group_id, api_key, **kwargs):
+         ep_url = 'https://api.minimax.chat/v1/text/chatcompletion'
+         self.endpoint = f'{ep_url}?GroupId={group_id}'
+         self.headers = {
+             'Authorization': f'Bearer {api_key}',
+             'Content-Type': 'application/json'
+         }
+
+     def parseChunkDelta(self, chunk):
+         decoded_data = chunk.decode('utf-8')
+         parsed_data = json.loads(decoded_data[6:])
+         delta_content = parsed_data['choices'][0]['delta']
+         return delta_content
+
+     def __call__(self, inp: ChatInput, verbose=False):
+         messages = inp.messages
+         model = inp.model
+         top_p = 0.95 if inp.top_p is None else inp.top_p
+         temperature = 0.9 if inp.temperature is None else inp.temperature
+         stream = False if inp.stream is None else inp.stream
+         max_tokens = 1024 if inp.max_tokens is None else inp.max_tokens
+         if abs(temperature) <= 1e-6:
+             temperature = 1e-6
+
+         chat_messages = messages
+         system_prompt = ('MM智能助理是一款由MinMax自研的,没有调用其他产品接口的大型语言'
+                          '模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。\n----\n')
+
+         if messages[0].role == 'system':
+             system_prompt = messages[0].content
+             chat_messages = messages[1:]
+
+         new_messages = []
+         for m in chat_messages:
+             role = 'USER'
+             if m.role == 'system' or m.role == 'assistant':
+                 role = 'BOT'
+
+             new_messages.append({'sender_type': role, 'text': m.content})
+
+         # if role_meta is given, the prompt must not be empty
+         system_info = {}
+         if system_prompt:
+             system_info = {
+                 'prompt': system_prompt,
+                 'role_meta': {
+                     'user_name': '用户',
+                     'bot_name': 'MM智能助理'
+                 }
+             }
+
+         payload = {
+             'model': model,
+             'stream': stream,
+             'use_standard_sse': True,
+             'messages': new_messages,
+             'temperature': temperature,
+             'top_p': top_p,
+             'tokens_to_generate': max_tokens
+         }
+         payload.update(system_info)
+
+         if verbose:
+             print('payload', payload)
+
+         response = requests.post(self.endpoint,
+                                  headers=self.headers,
+                                  json=payload)
+
+         req_type = 'chat.completion'
+         status_message = 'success'
+         status_code = response.status_code
+         created = get_ts()
+         choices = []
+         usage = None
+         if status_code == 200:
+             try:
+                 info = json.loads(response.text)
+                 if info['base_resp']['status_code'] == 0:
+                     created = info['created']
+                     # reply = info['reply']
+                     choices = []
+                     for s in info['choices']:
+                         index = s['index']
+                         finish_reason = s['finish_reason']
+                         msg = Message(role='assistant', content=s['text'])
+                         cho = Choice(index=index,
+                                      message=msg,
+                                      finish_reason=finish_reason)
+                         choices.append(cho)
+                     total_tokens = info['usage']['total_tokens']
+                     usage = Usage(total_tokens=total_tokens)
+                 else:
+                     status_code = info['base_resp']['status_code']
+                     status_message = info['base_resp']['status_msg']
+
+             except Exception as e:
+                 status_code = 401
+                 status_message = str(e)
+         else:
+             status_code = 400
+             status_message = 'requests error'
+
+         if status_code != 200:
+             raise Exception(status_message)
+
+         return ChatOutput(status_code=status_code,
+                           status_message=status_message,
+                           model=model,
+                           object=req_type,
+                           created=created,
+                           choices=choices,
+                           usage=usage)
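
The hunk above is the low-level MiniMax client (interface/minimax.py). A hedged usage sketch follows: the group id, API key, and model id are placeholders, and the client maps system/user messages onto MiniMax sender_type roles before returning a ChatOutput.

from bisheng_langchain.chat_models.interface import MinimaxChatCompletion
from bisheng_langchain.chat_models.interface.types import ChatInput, Message

# group_id and api_key are placeholders for real MiniMax credentials.
chat = MinimaxChatCompletion(group_id='your-group-id', api_key='your-api-key')

inp = ChatInput(
    model='abab5-chat',  # assumed model id; substitute the MiniMax model you have access to
    messages=[
        Message(role='system', content='You are a concise assistant.'),
        Message(role='user', content='Say hello in one sentence.'),
    ],
    temperature=0.7,
    max_tokens=256,
)

out = chat(inp, verbose=True)  # raises on a non-success status, otherwise returns ChatOutput
print(out.choices[0].message.content)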
bisheng_langchain/chat_models/interface/openai.py
@@ -0,0 +1,68 @@
+ # import json
+
+ import openai
+
+ from .types import ChatInput, ChatOutput, Choice, Usage
+ from .utils import get_ts
+
+
+ class ChatCompletion(object):
+
+     def __init__(self, api_key, proxy=None, **kwargs):
+         openai.api_key = api_key
+         openai.proxy = proxy
+
+     def __call__(self, inp: ChatInput, verbose=False):
+         messages = inp.messages
+         model = inp.model
+         top_p = 0.7 if inp.top_p is None else inp.top_p
+         temperature = 0.97 if inp.temperature is None else inp.temperature
+         # stream = False if inp.stream is None else inp.stream
+         max_tokens = 1024 if inp.max_tokens is None else inp.max_tokens
+         stop = None
+         if inp.stop is not None:
+             stop = inp.stop.split('||')
+
+         new_messages = [m.dict() for m in messages]
+         created = get_ts()
+         payload = {
+             'model': model,
+             'messages': new_messages,
+             'temperature': temperature,
+             'top_p': top_p,
+             'stop': stop,
+             'max_tokens': max_tokens,
+         }
+         if inp.functions:
+             payload.update({'functions': inp.functions})
+
+         if verbose:
+             print('payload', payload)
+
+         req_type = 'chat.completion'
+         status_message = 'success'
+         choices = []
+         usage = None
+         try:
+             resp = openai.ChatCompletion.create(**payload)
+             status_code = 200
+             choices = []
+             for choice in resp['choices']:
+                 cho = Choice(**choice)
+                 choices.append(cho)
+             usage = Usage(**resp['usage'])
+
+         except Exception as e:
+             status_code = 400
+             status_message = str(e)
+
+         if status_code != 200:
+             raise Exception(status_message)
+
+         return ChatOutput(status_code=status_code,
+                           status_message=status_message,
+                           model=model,
+                           object=req_type,
+                           created=created,
+                           choices=choices,
+                           usage=usage)
bisheng_langchain/chat_models/interface/types.py
@@ -0,0 +1,61 @@
+ # from typing import Union
+
+ from pydantic import BaseModel
+
+
+ class Message(BaseModel):
+     role: str
+     content: str
+
+
+ class Function(BaseModel):
+     name: str
+     description: str
+     parameters: dict
+
+
+ class ChatInput(BaseModel):
+     model: str
+     messages: list[Message] = []
+     top_p: float = None
+     temperature: float = None
+     n: int = 1
+     stream: bool = False
+     stop: str = None
+     max_tokens: int = 256
+     functions: list[Function] = []
+     function_call: str = None
+
+
+ class Choice(BaseModel):
+     index: int
+     message: Message = None
+     finish_reason: str = 'stop'
+
+
+ class Usage(BaseModel):
+     prompt_tokens: int = 0
+     completion_tokens: int = 0
+     total_tokens: int
+
+
+ class ChatOutput(BaseModel):
+     status_code: int
+     status_message: str = 'success'
+     id: str = None
+     object: str = None
+     model: str = None
+     created: int = None
+     choices: list[Choice] = []
+     usage: Usage = None
+
+
+ class CompletionsInput(BaseModel):
+     model: str
+     prompt: str
+     top_p: float = None
+     temperature: float = None
+     n: int = 1
+     stream: bool = True
+     stop: str = None
+     max_tokens: int = 256
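
A small, self-contained illustration of the request/response schema above (pure pydantic, no network call; the field values are invented). Note that stop is a single string which interface/openai.py splits on '||'.

from bisheng_langchain.chat_models.interface.types import (ChatInput, ChatOutput, Choice,
                                                            Message, Usage)

# Plain dicts are coerced into the nested pydantic models.
inp = ChatInput(
    model='gpt-3.5-turbo',
    messages=[{'role': 'user', 'content': 'hi'}],
    stop='Observation:||Final Answer:',  # '||'-separated stop strings
)
assert isinstance(inp.messages[0], Message)

out = ChatOutput(
    status_code=200,
    model=inp.model,
    choices=[Choice(index=0, message=Message(role='assistant', content='hello'))],
    usage=Usage(total_tokens=12),
)
print(out.json())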
bisheng_langchain/chat_models/interface/utils.py
@@ -0,0 +1,5 @@
+ import time
+
+
+ def get_ts():
+     return round(time.time() * 1000)