bisheng-langchain 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/__init__.py +0 -0
- bisheng_langchain/chains/__init__.py +5 -0
- bisheng_langchain/chains/combine_documents/__init__.py +0 -0
- bisheng_langchain/chains/combine_documents/stuff.py +56 -0
- bisheng_langchain/chains/question_answering/__init__.py +240 -0
- bisheng_langchain/chains/retrieval_qa/__init__.py +0 -0
- bisheng_langchain/chains/retrieval_qa/base.py +89 -0
- bisheng_langchain/chat_models/__init__.py +11 -0
- bisheng_langchain/chat_models/host_llm.py +409 -0
- bisheng_langchain/chat_models/interface/__init__.py +10 -0
- bisheng_langchain/chat_models/interface/minimax.py +123 -0
- bisheng_langchain/chat_models/interface/openai.py +68 -0
- bisheng_langchain/chat_models/interface/types.py +61 -0
- bisheng_langchain/chat_models/interface/utils.py +5 -0
- bisheng_langchain/chat_models/interface/wenxin.py +114 -0
- bisheng_langchain/chat_models/interface/xunfei.py +233 -0
- bisheng_langchain/chat_models/interface/zhipuai.py +81 -0
- bisheng_langchain/chat_models/minimax.py +354 -0
- bisheng_langchain/chat_models/proxy_llm.py +354 -0
- bisheng_langchain/chat_models/wenxin.py +349 -0
- bisheng_langchain/chat_models/xunfeiai.py +355 -0
- bisheng_langchain/chat_models/zhipuai.py +379 -0
- bisheng_langchain/document_loaders/__init__.py +3 -0
- bisheng_langchain/document_loaders/elem_html.py +0 -0
- bisheng_langchain/document_loaders/elem_image.py +0 -0
- bisheng_langchain/document_loaders/elem_pdf.py +655 -0
- bisheng_langchain/document_loaders/parsers/__init__.py +5 -0
- bisheng_langchain/document_loaders/parsers/image.py +28 -0
- bisheng_langchain/document_loaders/parsers/test_image.py +286 -0
- bisheng_langchain/embeddings/__init__.py +7 -0
- bisheng_langchain/embeddings/host_embedding.py +133 -0
- bisheng_langchain/embeddings/interface/__init__.py +3 -0
- bisheng_langchain/embeddings/interface/types.py +23 -0
- bisheng_langchain/embeddings/interface/wenxin.py +86 -0
- bisheng_langchain/embeddings/wenxin.py +139 -0
- bisheng_langchain/vectorstores/__init__.py +3 -0
- bisheng_langchain/vectorstores/elastic_keywords_search.py +284 -0
- bisheng_langchain-0.0.1.dist-info/METADATA +64 -0
- bisheng_langchain-0.0.1.dist-info/RECORD +41 -0
- bisheng_langchain-0.0.1.dist-info/WHEEL +5 -0
- bisheng_langchain-0.0.1.dist-info/top_level.txt +1 -0
bisheng_langchain/chat_models/wenxin.py
@@ -0,0 +1,349 @@
+"""proxy llm chat wrapper."""
+from __future__ import annotations
+
+import logging
+import sys
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+
+# import requests
+from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import ChatGeneration, ChatResult
+from langchain.schema.messages import (AIMessage, BaseMessage, ChatMessage, FunctionMessage,
+                                        HumanMessage, SystemMessage)
+from langchain.utils import get_from_dict_or_env
+from pydantic import Field, root_validator
+from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
+                      wait_exponential)
+
+from .interface import WenxinChatCompletion
+from .interface.types import ChatInput
+
+if TYPE_CHECKING:
+    import tiktoken
+
+logger = logging.getLogger(__name__)
+
+
+def _import_tiktoken() -> Any:
+    try:
+        import tiktoken
+    except ImportError:
+        raise ValueError('Could not import tiktoken python package. '
+                         'This is needed in order to calculate get_token_ids. '
+                         'Please install it with `pip install tiktoken`.')
+    return tiktoken
+
+
+def _create_retry_decorator(llm: ChatWenxin) -> Callable[[Any], Any]:
+
+    min_seconds = 1
+    max_seconds = 20
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+    return retry(
+        reraise=True,
+        stop=stop_after_attempt(llm.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=(retry_if_exception_type(Exception)),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+    role = _dict['role']
+    if role == 'user':
+        return HumanMessage(content=_dict['content'])
+    elif role == 'assistant':
+        content = _dict[
+            'content'] or ''  # OpenAI returns None for tool invocations
+        if _dict.get('function_call'):
+            additional_kwargs = {'function_call': dict(_dict['function_call'])}
+        else:
+            additional_kwargs = {}
+        return AIMessage(content=content, additional_kwargs=additional_kwargs)
+    elif role == 'system':
+        return SystemMessage(content=_dict['content'])
+    elif role == 'function':
+        return FunctionMessage(content=_dict['content'], name=_dict['name'])
+    else:
+        return ChatMessage(content=_dict['content'], role=role)
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+    if isinstance(message, ChatMessage):
+        message_dict = {'role': message.role, 'content': message.content}
+    elif isinstance(message, HumanMessage):
+        message_dict = {'role': 'user', 'content': message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {'role': 'assistant', 'content': message.content}
+        if 'function_call' in message.additional_kwargs:
+            message_dict['function_call'] = message.additional_kwargs[
+                'function_call']
+    elif isinstance(message, SystemMessage):
+        message_dict = {'role': 'system', 'content': message.content}
+    elif isinstance(message, FunctionMessage):
+        message_dict = {
+            'role': 'function',
+            'content': message.content,
+            'name': message.name,
+        }
+    else:
+        raise ValueError(f'Got unknown type {message}')
+    if 'name' in message.additional_kwargs:
+        message_dict['name'] = message.additional_kwargs['name']
+    return message_dict
+
+
+class ChatWenxin(BaseChatModel):
+    """Wrapper around proxy Chat large language models.
+
+    To use, the environment variable ``ELEMAI_API_KEY`` set with your API key.
+
+    Example:
+        .. code-block:: python
+
+            from bisheng_langchain.chat_models import ChatWenxin
+            chat_miniamaxai = ChatWenxin(model_name="ernie-bot")
+    """
+
+    client: Optional[Any]  #: :meta private:
+    """Model name to use."""
+    model_name: str = Field('ernie-bot', alias='model')
+
+    temperature: float = 0.95
+    top_p: float = 0.8
+    """What sampling temperature to use."""
+    model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
+    wenxin_api_key: Optional[str] = None
+    wenxin_secret_key: Optional[str] = None
+
+    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    max_retries: Optional[int] = 6
+    """Maximum number of retries to make when generating."""
+    streaming: Optional[bool] = False
+    """Whether to stream the results or not."""
+    n: Optional[int] = 1
+    """Number of chat completions to generate for each prompt."""
+    max_tokens: Optional[int] = None
+    """Maximum number of tokens to generate."""
+    tiktoken_model_name: Optional[str] = None
+    """The model name to pass to tiktoken when using this class.
+    Tiktoken is used to count the number of tokens in documents to constrain
+    them to be under a certain limit. By default, when set to None, this will
+    be the same as the embedding model name. However, there are some cases
+    where you may want to use this Embedding class with a model name not
+    supported by tiktoken. This can include when using Azure embeddings or
+    when using one of the many model providers that expose an OpenAI-like
+    API but with different models. In those cases, in order to avoid erroring
+    when tiktoken is called, you can specify a model name to use here."""
+    verbose: Optional[bool] = False
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values['wenxin_api_key'] = get_from_dict_or_env(
+            values, 'wenxin_api_key', 'WENXIN_API_KEY')
+
+        values['wenxin_secret_key'] = get_from_dict_or_env(
+            values, 'wenxin_secret_key', 'WENXIN_SECRET_KEY')
+
+        api_key = values['wenxin_api_key']
+        secret_key = values['wenxin_secret_key']
+        try:
+            values['client'] = WenxinChatCompletion(api_key, secret_key)
+        except AttributeError:
+            raise ValueError(
+                'Try upgrading it with `pip install --upgrade requests`.')
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling ChatWenxin API."""
+        return {
+            'model': self.model_name,
+            'temperature': self.temperature,
+            'top_p': self.top_p,
+            'max_tokens': self.max_tokens,
+            **self.model_kwargs,
+        }
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        retry_decorator = _create_retry_decorator(self)
+
+        @retry_decorator
+        def _completion_with_retry(**kwargs: Any) -> Any:
+            messages = kwargs.get('messages')
+            temperature = kwargs.get('temperature')
+            top_p = kwargs.get('top_p')
+            max_tokens = kwargs.get('max_tokens')
+            params = {
+                'messages': messages,
+                'model': self.model_name,
+                'top_p': top_p,
+                'temperature': temperature,
+                'max_tokens': max_tokens
+            }
+            return self.client(ChatInput.parse_obj(params),
+                               self.verbose).dict()
+
+        return _completion_with_retry(**kwargs)
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        overall_token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output['token_usage']
+            for k, v in token_usage.items():
+                if k in overall_token_usage:
+                    overall_token_usage[k] += v
+                else:
+                    overall_token_usage[k] = v
+        return {
+            'token_usage': overall_token_usage,
+            'model_name': self.model_name
+        }
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
+
+        response = self.completion_with_retry(messages=message_dicts, **params)
+        return self._create_chat_result(response)
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return self._generate(messages, stop, run_manager, kwargs)
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        params = dict(self._client_params)
+        if stop is not None:
+            if 'stop' in params:
+                raise ValueError(
+                    '`stop` found in both the input and default params.')
+            params['stop'] = stop
+
+        message_dicts = [_convert_message_to_dict(m) for m in messages]
+
+        return message_dicts, params
+
+    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
+        generations = []
+        for res in response['choices']:
+            message = _convert_dict_to_message(res['message'])
+            gen = ChatGeneration(message=message)
+            generations.append(gen)
+
+        llm_output = {
+            'token_usage': response['usage'],
+            'model_name': self.model_name
+        }
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{'model_name': self.model_name}, **self._default_params}
+
+    @property
+    def _client_params(self) -> Mapping[str, Any]:
+        """Get the parameters used for the client."""
+        minimaxai_creds: Dict[str, Any] = {
+            'model': self.model_name,
+        }
+        return {**minimaxai_creds, **self._default_params}
+
+    def _get_invocation_params(self,
+                               stop: Optional[List[str]] = None,
+                               **kwargs: Any) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+        return {
+            **super()._get_invocation_params(stop=stop, **kwargs),
+            **self._default_params,
+            'model': self.model_name,
+            'function': kwargs.get('functions'),
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return 'ernie-bot-chat'
+
+    def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
+        tiktoken_ = _import_tiktoken()
+        if self.tiktoken_model_name is not None:
+            model = self.tiktoken_model_name
+        else:
+            model = self.model_name
+            # model chatglm-std, chatglm-lite
+        # Returns the number of tokens used by a list of messages.
+        try:
+            encoding = tiktoken_.encoding_for_model(model)
+        except KeyError:
+            logger.warning(
+                'Warning: model not found. Using cl100k_base encoding.')
+            model = 'cl100k_base'
+            encoding = tiktoken_.get_encoding(model)
+        return model, encoding
+
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the tokens present in the text with tiktoken package."""
+        # tiktoken NOT supported for Python 3.7 or below
+        if sys.version_info[1] <= 7:
+            return super().get_token_ids(text)
+        _, encoding_model = self._get_encoding_model()
+        return encoding_model.encode(text)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Calculate num tokens for chatglm with tiktoken package.
+
+        todo: read chatglm document
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if sys.version_info[1] <= 7:
+            return super().get_num_tokens_from_messages(messages)
+        model, encoding = self._get_encoding_model()
+        if model.startswith('chatglm'):
+            # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            tokens_per_message = 4
+            # if there's a name, the role is omitted
+            tokens_per_name = -1
+        else:
+            raise NotImplementedError(
+                f'get_num_tokens_from_messages() is not presently implemented '
+                f'for model {model}.'
+                'See https://github.com/openai/openai-python/blob/main/chatml.md for '
+                'information on how messages are converted to tokens.')
+        num_tokens = 0
+        messages_dict = [_convert_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                num_tokens += len(encoding.encode(value))
+                if key == 'name':
+                    num_tokens += tokens_per_name
+        # every reply is primed with <im_start>assistant
+        num_tokens += 3
+        return num_tokens
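For orientation, below is a minimal usage sketch of the `ChatWenxin` wrapper added in this hunk. It is not part of the wheel contents; it assumes `WENXIN_API_KEY` and `WENXIN_SECRET_KEY` are set in the environment (as required by `validate_environment`), a langchain version compatible with the imports above, and that `bisheng_langchain.chat_models` re-exports `ChatWenxin` as the class docstring suggests. The prompt strings are placeholders.

```python
# Hypothetical usage sketch (not part of the package diff) for the ChatWenxin class above.
# Assumes WENXIN_API_KEY / WENXIN_SECRET_KEY are exported and a compatible langchain install.
from langchain.schema.messages import HumanMessage, SystemMessage

from bisheng_langchain.chat_models import ChatWenxin

chat = ChatWenxin(model_name='ernie-bot', temperature=0.95, top_p=0.8)

messages = [
    SystemMessage(content='You are a helpful assistant.'),
    HumanMessage(content='Summarize what an ERNIE Bot chat wrapper does.'),
]

# generate() drives _generate() -> completion_with_retry() and returns a result whose
# llm_output carries token_usage and model_name (see _create_chat_result above).
result = chat.generate([messages])
print(result.generations[0][0].message.content)
print(result.llm_output)
```

As shipped in this release, `_agenerate` simply delegates to the synchronous `_generate`, and the tenacity decorator retries on any `Exception` up to `max_retries` times with exponential backoff.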