bisheng-langchain 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/__init__.py +0 -0
- bisheng_langchain/chains/__init__.py +5 -0
- bisheng_langchain/chains/combine_documents/__init__.py +0 -0
- bisheng_langchain/chains/combine_documents/stuff.py +56 -0
- bisheng_langchain/chains/question_answering/__init__.py +240 -0
- bisheng_langchain/chains/retrieval_qa/__init__.py +0 -0
- bisheng_langchain/chains/retrieval_qa/base.py +89 -0
- bisheng_langchain/chat_models/__init__.py +11 -0
- bisheng_langchain/chat_models/host_llm.py +409 -0
- bisheng_langchain/chat_models/interface/__init__.py +10 -0
- bisheng_langchain/chat_models/interface/minimax.py +123 -0
- bisheng_langchain/chat_models/interface/openai.py +68 -0
- bisheng_langchain/chat_models/interface/types.py +61 -0
- bisheng_langchain/chat_models/interface/utils.py +5 -0
- bisheng_langchain/chat_models/interface/wenxin.py +114 -0
- bisheng_langchain/chat_models/interface/xunfei.py +233 -0
- bisheng_langchain/chat_models/interface/zhipuai.py +81 -0
- bisheng_langchain/chat_models/minimax.py +354 -0
- bisheng_langchain/chat_models/proxy_llm.py +354 -0
- bisheng_langchain/chat_models/wenxin.py +349 -0
- bisheng_langchain/chat_models/xunfeiai.py +355 -0
- bisheng_langchain/chat_models/zhipuai.py +379 -0
- bisheng_langchain/document_loaders/__init__.py +3 -0
- bisheng_langchain/document_loaders/elem_html.py +0 -0
- bisheng_langchain/document_loaders/elem_image.py +0 -0
- bisheng_langchain/document_loaders/elem_pdf.py +655 -0
- bisheng_langchain/document_loaders/parsers/__init__.py +5 -0
- bisheng_langchain/document_loaders/parsers/image.py +28 -0
- bisheng_langchain/document_loaders/parsers/test_image.py +286 -0
- bisheng_langchain/embeddings/__init__.py +7 -0
- bisheng_langchain/embeddings/host_embedding.py +133 -0
- bisheng_langchain/embeddings/interface/__init__.py +3 -0
- bisheng_langchain/embeddings/interface/types.py +23 -0
- bisheng_langchain/embeddings/interface/wenxin.py +86 -0
- bisheng_langchain/embeddings/wenxin.py +139 -0
- bisheng_langchain/vectorstores/__init__.py +3 -0
- bisheng_langchain/vectorstores/elastic_keywords_search.py +284 -0
- bisheng_langchain-0.0.1.dist-info/METADATA +64 -0
- bisheng_langchain-0.0.1.dist-info/RECORD +41 -0
- bisheng_langchain-0.0.1.dist-info/WHEEL +5 -0
- bisheng_langchain-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,286 @@
|
|
1
|
+
# flake8: noqa
|
2
|
+
import io
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
import random
|
7
|
+
import tempfile
|
8
|
+
import time
|
9
|
+
from abc import ABC
|
10
|
+
from copy import deepcopy
|
11
|
+
from pathlib import Path
|
12
|
+
from typing import Any, Iterator, List, Mapping, Optional, Union
|
13
|
+
from urllib.parse import urlparse
|
14
|
+
|
15
|
+
import cv2
|
16
|
+
import fitz
|
17
|
+
import numpy as np
|
18
|
+
import pypdfium2
|
19
|
+
import requests
|
20
|
+
from image import LayoutParser
|
21
|
+
from langchain.document_loaders.blob_loaders import Blob
|
22
|
+
|
23
|
+
|
24
|
+
def norm_rect(bbox):
    """Return ``bbox`` with its corners ordered as ``[x_min, y_min, x_max, y_max]``.

    ``bbox`` is any 4-element sequence ``(xa, ya, xb, yb)`` whose two corners
    may appear in either order. The result is a numpy array.
    """
    xs = (bbox[0], bbox[2])
    ys = (bbox[1], bbox[3])
    return np.asarray([np.min(xs), np.min(ys), np.max(xs), np.max(ys)])
|
30
|
+
|
31
|
+
|
32
|
+
def merge_rects(bboxes):
    """Return the smallest rectangle enclosing every row of ``bboxes``.

    ``bboxes`` is indexed with ``[:, k]``, so it must be a 2-D numpy array
    whose rows start with ``[x0, y0, x1, y1]``. The result is a plain list
    ``[min x0, min y0, max x1, max y1]``.
    """
    lows = np.min(bboxes[:, 0:2], axis=0)
    highs = np.max(bboxes[:, 2:4], axis=0)
    return [lows[0], lows[1], highs[0], highs[1]]
|
38
|
+
|
39
|
+
|
40
|
+
def get_image_blobs(pages, pdf_reader, n=None, start=0):
    """Render a range of PDF pages to PNG and wrap each page in a ``Blob``.

    Args:
        pages: an open PyMuPDF (``fitz``) document, rendered page by page.
        pdf_reader: a ``pypdfium2.PdfDocument`` over the same PDF, used as a
            fallback renderer for pages PyMuPDF fails to rasterise.
        n: number of pages to render. ``None`` (the default) renders every
            page from ``start`` to the end of ``pages``.
        start: zero-based index of the first page to render.

    Returns:
        list of ``Blob`` objects, one PNG image per rendered page, in order.
    """
    if n is None:
        n = pages.page_count - start

    # Zoom factor 1 in both directions: rasterise at the page's native size.
    mat = fitz.Matrix(1, 1)
    blobs = []
    for pg in range(start, start + n):
        page = pages.load_page(pg)
        try:
            pm = page.get_pixmap(matrix=mat, alpha=False)
            # ``tobytes('png')`` replaces the deprecated camelCase
            # ``getPNGData`` alias (absent in later PyMuPDF releases); with
            # the old name every page silently fell through to pdfium.
            bytes_img = pm.tobytes('png')
        except Exception:
            # some pdf input cannot get render image from fitz
            pil_image = pdf_reader.get_page(pg).render().to_pil()
            img_byte_arr = io.BytesIO()
            pil_image.save(img_byte_arr, format='PNG')
            bytes_img = img_byte_arr.getvalue()

        blobs.append(Blob(data=bytes_img))
    return blobs
|
59
|
+
|
60
|
+
|
61
|
+
def test():
    """Render every page of a sample PDF and run layout parsing on page one."""
    pdf_path = './data/达梦数据库招股说明书_test_v1.pdf'
    blob = Blob.from_path(pdf_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        # get_image_blobs needs an explicit page count; omitting it was a
        # TypeError.
        image_blobs = get_image_blobs(pages, pdf_reader, pages.page_count)

    assert len(image_blobs) == pages.page_count
    layout = LayoutParser()
    layout.parse(image_blobs[0])
|
74
|
+
|
75
|
+
|
76
|
+
def draw_polygon(image, bbox, text=None, color=(255, 0, 0), thickness=1):
    """Draw ``bbox`` onto ``image`` and optionally label it.

    A 4-element ``bbox`` is drawn as an axis-aligned rectangle
    ``[x0, y0, x1, y1]``. Any other length is treated as flat polygon
    vertices ``[x0, y0, x1, y1, ...]`` and drawn as a closed polyline.
    Coordinates are cast to int32 first. When ``text`` is truthy it is
    written at the first corner using colour ``(0, 0, 255)``, independent
    of ``color``.

    Returns:
        The image array returned by the OpenCV drawing calls.
    """
    pts = bbox.astype(np.int32)
    if pts.shape[0] != 4:
        polys = [pts.reshape((-1, 1, 2))]
        cv2.polylines(image, polys, True, color=color, thickness=thickness)
        anchor = (polys[0][0, 0, 0], polys[0][0, 0, 1])
    else:
        anchor = (pts[0], pts[1])
        image = cv2.rectangle(image, anchor, (pts[2], pts[3]), color,
                              thickness)

    if not text:
        return image
    return cv2.putText(image, text, anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                       (0, 0, 255), 1)
|
96
|
+
|
97
|
+
|
98
|
+
def test_vis():
    """Visualise layout-parser detections for pages [start, start + n).

    Each page is rendered and passed through ``LayoutParser.parse``. Every
    returned ``bbox`` is then drawn labelled with its ``category_id``, and
    the overlay is written to ``./data/<prefix>_layout_p<page>_vis.png``,
    where the page number is 1-based.
    """
    file_path = './data/pdf_input/《中国药典》2020年版 一部.pdf'
    output_prefix = 'zhongguoyaodian_2020_v1'
    start, n = 70, 10
    blob = Blob.from_path(file_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        image_blobs = get_image_blobs(pages, pdf_reader, n, start)

    assert len(image_blobs) == n

    # Build the parser once rather than per page.
    # NOTE(review): assumes parse() keeps no per-call state.
    layout = LayoutParser()
    for i, page_blob in enumerate(image_blobs):
        idx = i + start
        out = layout.parse(page_blob)
        res = json.loads(out[0].page_content)
        bboxes = np.asarray([r['bbox'] for r in res])
        labels = [str(r['category_id']) for r in res]

        bytes_arr = np.frombuffer(page_blob.as_bytes(), dtype=np.uint8)
        image = cv2.imdecode(bytes_arr, flags=1)
        for bbox, text in zip(bboxes, labels):
            image = draw_polygon(image, bbox, text)

        outf = f'./data/{output_prefix}_layout_p{idx+1}_vis.png'
        cv2.imwrite(outf, image)
|
134
|
+
|
135
|
+
|
136
|
+
def order_by_tbyx(block_info, th=10):
    """Order blocks top-to-bottom, then left-to-right within a line.

    Args:
        block_info: sequence of block tuples ``(x0, y0, x1, y1, text, ...)``.
            Only index 0 (x0) and index 1 (y0) are read.
        th: blocks whose y0 values differ by less than ``th`` are treated as
            lying on the same line and are ordered by x0.

    Returns:
        A new list with the blocks in reading order.
    """
    # Primary order: by y0, then x0.
    res = sorted(block_info, key=lambda b: (b[1], b[0]))
    for i in range(len(res) - 1):
        # Insertion pass: move res[i + 1] left past neighbours that sit on
        # the same line (|dy| < th) but further to the right. j must reach 0
        # so the first pair of blocks is compared as well.
        for j in range(i, -1, -1):
            if (abs(res[j + 1][1] - res[j][1]) < th
                    and res[j + 1][0] < res[j][0]):
                res[j], res[j + 1] = res[j + 1], res[j]
            else:
                break
    return res
|
154
|
+
|
155
|
+
|
156
|
+
def test_vis2():
    """Visualise PyMuPDF text/image blocks for pages [start, end) of a PDF.

    For each rendered page this draws three things:
    - the page rect inset by 10 units, labelled ``'0.0'``;
    - every block from ``extractBLOCKS``, labelled ``'<text|image>-<index>'``;
    - the union of all blocks, labelled ``'0.1'``.

    The result is written to ``./data/<prefix>_p<page>_vis.png``.
    """
    file_path = './data/pdf_input/达梦数据库招股说明书.pdf'
    output_prefix = 'dameng_pageblock'

    start = 0
    end = 10
    n = end - start
    blob = Blob.from_path(file_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        image_blobs = get_image_blobs(pages, pdf_reader, n, start)

    # Only n pages were rendered; comparing with the document's total page
    # count fails for any PDF that does not have exactly n pages.
    assert len(image_blobs) == n

    # The last element of a PyMuPDF block tuple is its type; 1 is an image.
    IMG_BLOCK_TYPE = 1
    o = 10
    for i, page_blob in enumerate(image_blobs):
        idx = i + start
        page = pages.load_page(idx)

        rect = page.rect
        print('rect', rect)
        b0 = np.asarray([rect.x0 + o, rect.y0 + o, rect.x1 - o, rect.y1 - o])

        bytes_arr = np.frombuffer(page_blob.as_bytes(), dtype=np.uint8)
        image = cv2.imdecode(bytes_arr, flags=1)
        image = draw_polygon(image, b0, '0.0')

        blocks = page.get_textpage().extractBLOCKS()
        bboxes = []
        for off, b in enumerate(blocks):
            label = 'text' if b[-1] != IMG_BLOCK_TYPE else 'image'
            label = f'{label}-{off}'
            print('block', b, label)
            bbox = np.asarray([b[0], b[1], b[2], b[3]])
            bboxes.append(bbox)
            image = draw_polygon(image, bbox, label)

        if bboxes:
            b1 = np.asarray(merge_rects(np.asarray(bboxes)))
            image = draw_polygon(image, b1, '0.1')

        outf = f'./data/{output_prefix}_p{idx}_vis.png'
        cv2.imwrite(outf, image)
|
210
|
+
|
211
|
+
|
212
|
+
def test_vis3():
    """Visualise PyMuPDF blocks after mapping them through the page rotation.

    For pages [start, end) this draws three things:
    - the page rect inset by 10 units (``'0.0'``);
    - each ``extractBLOCKS`` box transformed by ``page.rotation_matrix``;
    - the union of the transformed boxes (``'0.1'``).

    The result is written to ``./data/<prefix>_p<page>_vis.png``.
    """
    file_path = './data/pdf_input/《中国药典》2020年版 一部.pdf'

    start = 50
    end = 60
    n = end - start
    output_prefix = 'zhongguoyaodian_2020_v1'

    blob = Blob.from_path(file_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        # Pass ``start`` (not a repeated literal) so the rendered range always
        # matches the page indices used in the loop below.
        image_blobs = get_image_blobs(pages, pdf_reader, n, start)

    assert len(image_blobs) == n

    # The last element of a PyMuPDF block tuple is its type; 1 is an image.
    IMG_BLOCK_TYPE = 1
    o = 10
    for i, page_blob in enumerate(image_blobs):
        idx = i + start
        page = pages.load_page(idx)

        rect = page.rect
        print('rect', rect)
        b0 = np.asarray([rect.x0 + o, rect.y0 + o, rect.x1 - o, rect.y1 - o])

        bytes_arr = np.frombuffer(page_blob.as_bytes(), dtype=np.uint8)
        image = cv2.imdecode(bytes_arr, flags=1)
        image = draw_polygon(image, b0, '0.0')

        # The 6 matrix values (a, b, c, d, e, f) as a 3x2 matrix applied to
        # homogeneous row vectors: [x, y, 1] @ M.
        rotation_matrix = np.asarray(page.rotation_matrix).reshape((3, 2))

        blocks = page.get_textpage().extractBLOCKS()
        bboxes = []
        for off, b in enumerate(blocks):
            label = 'text' if b[-1] != IMG_BLOCK_TYPE else 'image'
            label = f'{label}-{off}'
            print('block', b, label)
            bbox = np.asarray([b[0], b[1], b[2], b[3]])

            # Transform both corners, then re-normalise min/max ordering.
            corners = bbox.reshape((-1, 2))
            corners = np.hstack([corners, np.ones((len(corners), 1))])
            new_bbox = norm_rect(np.dot(corners, rotation_matrix).reshape(-1))

            print('new_bboxes', new_bbox)
            bboxes.append(new_bbox)
            image = draw_polygon(image, new_bbox, label)

        print(bboxes)
        if bboxes:
            b1 = np.asarray(merge_rects(np.asarray(bboxes)))
            image = draw_polygon(image, b1, '0.1')

        outf = f'./data/{output_prefix}_p{idx}_vis.png'
        cv2.imwrite(outf, image)
|
281
|
+
|
282
|
+
|
283
|
+
# Run a visualisation only when executed as a script. At import time (e.g.
# during pytest collection) this would read local PDFs and write PNGs.
if __name__ == '__main__':
    # test_vis3()
    # test_vis2()
    test_vis()
    # test()
|
@@ -0,0 +1,133 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
5
|
+
|
6
|
+
import requests
|
7
|
+
from langchain.embeddings.base import Embeddings
|
8
|
+
from langchain.utils import get_from_dict_or_env
|
9
|
+
from pydantic import BaseModel, Extra, Field, root_validator
|
10
|
+
from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
|
11
|
+
wait_exponential)
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
def _create_retry_decorator(
        embeddings: HostEmbeddings) -> Callable[[Any], Any]:
    """Build a tenacity ``retry`` decorator configured from ``embeddings``.

    Retry policy:
    - any ``Exception`` triggers a retry;
    - waits grow exponentially (multiplier 1), clamped to 4-10 seconds;
    - at most ``embeddings.max_retries`` attempts, after which the last
      exception is re-raised;
    - a WARNING is logged before each sleep.
    """
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    give_up = stop_after_attempt(embeddings.max_retries)
    return retry(
        reraise=True,
        stop=give_up,
        wait=backoff,
        retry=retry_if_exception_type(Exception),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
|
29
|
+
|
30
|
+
|
31
|
+
def embed_with_retry(embeddings: HostEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call.

    Calls ``embeddings.embed(**kwargs)`` under the policy built by
    ``_create_retry_decorator`` and returns its result.
    """

    def _embed_with_retry(**call_kwargs: Any) -> Any:
        return embeddings.embed(**call_kwargs)

    with_retries = _create_retry_decorator(embeddings)(_embed_with_retry)
    return with_retries(**kwargs)
|
40
|
+
|
41
|
+
|
42
|
+
class HostEmbeddings(BaseModel, Embeddings):
    """Embeddings served by a self-hosted HTTP inference endpoint.

    Requests are POSTed as JSON ``{'texts', 'model', 'type'}`` to
    ``{host_base_url}/{model}/infer``. The JSON reply must contain
    ``status_code`` (200 on success), ``status_message`` and ``embeddings``.
    ``host_base_url`` may be passed directly or read from the ``HostBaseUrl``
    environment variable.
    """

    client: Optional[Any]  #: :meta private:
    # Model name to use; also the path segment of the inference URL.
    model: str = 'embedding-host'
    host_base_url: Optional[str] = None

    deployment: Optional[str] = 'default'

    # The maximum number of tokens to embed at once (not enforced here).
    embedding_ctx_length: Optional[int] = 6144
    # Maximum number of attempts made by ``embed_with_retry``.
    max_retries: Optional[int] = 6
    # Timeout in seconds for the HTTP request, forwarded to ``requests`` as
    # ``timeout`` (a float or a ``(connect, read)`` tuple). None means no limit.
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None

    # Holds any model parameters valid for `create` call not explicitly specified.
    model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)

    # When True, print every request payload.
    verbose: Optional[bool] = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve ``host_base_url`` and bind ``requests.post`` as the client."""
        values['host_base_url'] = get_from_dict_or_env(values, 'host_base_url',
                                                       'HostBaseUrl')

        try:
            values['client'] = requests.post
        except AttributeError:
            raise ValueError(
                'Try upgrading it with `pip install --upgrade requests`.')
        return values

    @property
    def _invocation_params(self) -> Dict:
        api_args = {
            'model': self.model,
            'request_timeout': self.request_timeout,
            **self.model_kwargs,
        }
        return api_args

    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """POST ``texts`` to the host endpoint and return the embeddings.

        Keyword Args:
            type: embedding type sent to the server. Defaults to ``'raw'``;
                ``embed_documents`` uses ``'doc'`` and ``embed_query``
                uses ``'query'``.

        Raises:
            ValueError: if the reply's ``status_code`` is not 200.
        """
        emb_type = kwargs.get('type', 'raw')
        inp = {'texts': texts, 'model': self.model, 'type': emb_type}
        if self.verbose:
            print('payload', inp)

        url = f'{self.host_base_url}/{self.model}/infer'
        # Apply request_timeout so a hung server cannot block indefinitely.
        outp = self.client(url=url, json=inp,
                           timeout=self.request_timeout).json()
        if outp['status_code'] != 200:
            raise ValueError(
                f"API returned an error: {outp['status_message']}")
        return outp['embeddings']

    def embed_documents(self,
                        texts: List[str],
                        chunk_size: Optional[int] = 0) -> List[List[float]]:
        """Embed search docs.

        Empty strings are dropped before the request, so the result can be
        shorter than ``texts``. NOTE(review): callers that zip ``texts`` with
        the result will misalign when empty strings are present; confirm
        this is intended. ``chunk_size`` is accepted for interface
        compatibility and ignored.
        """
        texts = [text for text in texts if text]
        if not texts:
            # Nothing left to embed: skip the network round-trip.
            return []
        embeddings = embed_with_retry(self, texts=texts, type='doc')
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector."""
        embeddings = embed_with_retry(self, texts=[text], type='query')
        return embeddings[0]
|
119
|
+
|
120
|
+
|
121
|
+
class ME5Embedding(HostEmbeddings):
    """``HostEmbeddings`` preset for the hosted ``multi-e5`` model."""

    model: str = "multi-e5"
    embedding_ctx_length: int = 512
|
124
|
+
|
125
|
+
|
126
|
+
class BGEZhEmbedding(HostEmbeddings):
    """``HostEmbeddings`` preset for the hosted ``bge-zh`` model."""

    model: str = "bge-zh"
    embedding_ctx_length: int = 512
|
129
|
+
|
130
|
+
|
131
|
+
class GTEEmbedding(HostEmbeddings):
    """``HostEmbeddings`` preset for the hosted ``gte`` model."""

    model: str = "gte"
    embedding_ctx_length: int = 512
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from typing import Any, Dict, List, Union
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
|
6
|
+
class EmbeddingInput(BaseModel):
    """Embedding request body: a model name plus one text or a list of texts."""

    model: str  # model identifier
    input: Union[str, List[str]]  # a single text or a batch of texts
|
9
|
+
|
10
|
+
|
11
|
+
class Embedding(BaseModel):
    """One embedding vector plus an ``index`` (presumably its batch position)."""

    object: str = "embedding"
    embedding: List[float]
    index: int
|
15
|
+
|
16
|
+
|
17
|
+
class EmbeddingOutput(BaseModel):
    """Embedding API response envelope.

    Only ``status_code`` is required. ``object``, ``model`` and ``usage``
    default to None, so they are annotated as optional rather than relying
    on implicit-Optional typing.
    """

    status_code: int
    status_message: str = 'success'
    object: Union[str, None] = None
    data: List[Embedding] = []
    model: Union[str, None] = None
    usage: Union[Dict[str, Any], None] = None
|
@@ -0,0 +1,86 @@
|
|
1
|
+
import json
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import requests
|
5
|
+
from requests.exceptions import HTTPError
|
6
|
+
|
7
|
+
|
8
|
+
def get_access_token(api_key, sec_key):
    """Exchange a Wenxin API key / secret key pair for an OAuth access token.

    Sends a client-credentials POST to Baidu's OAuth endpoint and returns the
    ``access_token`` field of the JSON reply, or ``None`` if it is absent.
    """
    token_url = ('https://aip.baidubce.com/oauth/2.0/token?'
                 'grant_type=client_credentials'
                 f'&client_id={api_key}&client_secret={sec_key}')
    reply = requests.request(
        'POST',
        token_url,
        headers={
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        },
        data=json.dumps(''),
    )
    body = reply.json()
    return body.get('access_token')
|
21
|
+
|
22
|
+
|
23
|
+
class EmbeddingClient(object):
    """Minimal client for the Wenxin embeddings endpoint.

    Texts are sent in batches of at most ``max_text_num``. By default each
    text is first truncated to ``max_text_tokens`` characters.
    """

    def __init__(self, api_key, sec_key, **kwargs):
        self.api_key = api_key
        self.sec_key = sec_key
        self.ep_url = ('https://aip.baidubce.com/rpc/2.0/ai_custom/v1/'
                       'wenxinworkshop/embeddings')
        self.headers = {'Content-Type': 'application/json'}
        # Per-text length limit; measured with len(str), i.e. characters.
        self.max_text_tokens = 384
        # Maximum number of texts sent in one request.
        self.max_text_num = 16
        # True: truncate over-long texts. False: reject them with HTTPError.
        self.drop_exceed_token = kwargs.get('drop_exceed_token', True)

    def create(self, model, input, verbose=False, **kwargs):
        """Embed ``input`` (a string or a list of strings) with ``model``.

        Returns:
            dict with these keys:
            - ``status_code``: always 200;
            - ``model``;
            - ``data``: the per-batch API ``data`` lists, concatenated;
            - ``usage``: summed ``prompt_tokens`` / ``total_tokens``.

        Raises:
            HTTPError: in any of these cases:
            - a text is too long and ``drop_exceed_token`` is False;
            - the HTTP response is not 200;
            - the API reports an ``error_code``;
            - the reply cannot be parsed.
        """
        texts = [input] if isinstance(input, str) else input

        if self.drop_exceed_token:
            texts = [t[:self.max_text_tokens] for t in texts]

        if not all(len(text) <= self.max_text_tokens for text in texts):
            raise HTTPError(
                f'text exceed max token size {self.max_text_tokens}')

        data = []
        usage = {'prompt_tokens': 0, 'total_tokens': 0}
        if not texts:
            # Nothing to embed: skip the OAuth round-trip entirely.
            return dict(status_code=200, model=model, data=data, usage=usage)

        token = get_access_token(self.api_key, self.sec_key)
        endpoint = f'{self.ep_url}/{model}?access_token={token}'

        def _call(sub_texts):
            payload = json.dumps({'input': sub_texts})
            response = requests.post(endpoint,
                                     headers=self.headers,
                                     data=payload)
            if response.status_code != 200:
                raise HTTPError('requests error')
            try:
                info = json.loads(response.text)
                error_code = info.get('error_code', 200)
                error_msg = info.get('error_msg', 'success')
                if error_code == 200:
                    return info['data'], info['usage']
            except Exception as e:
                # Malformed reply (bad JSON, missing keys): keep the cause.
                raise HTTPError(str(e)) from e
            # API-level error: raised outside the try so it is not re-wrapped.
            raise HTTPError(error_msg)

        for i in range(0, len(texts), self.max_text_num):
            sub_data, sub_usage = _call(texts[i:i + self.max_text_num])
            data.extend(sub_data)
            usage['prompt_tokens'] += sub_usage['prompt_tokens']
            usage['total_tokens'] += sub_usage['total_tokens']

        return dict(status_code=200, model=model, data=data, usage=usage)
|
@@ -0,0 +1,139 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
# import warnings
|
5
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
6
|
+
|
7
|
+
# import numpy as np
|
8
|
+
from langchain.embeddings.base import Embeddings
|
9
|
+
from langchain.utils import get_from_dict_or_env
|
10
|
+
from pydantic import BaseModel, Extra, Field, root_validator
|
11
|
+
from requests.exceptions import HTTPError
|
12
|
+
from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
|
13
|
+
wait_exponential)
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
def _create_retry_decorator(
        embeddings: WenxinEmbeddings) -> Callable[[Any], Any]:
    """Build a tenacity ``retry`` decorator configured from ``embeddings``.

    Retry policy:
    - only ``requests`` ``HTTPError`` triggers a retry;
    - waits grow exponentially (multiplier 1), clamped to 4-10 seconds;
    - at most ``embeddings.max_retries`` attempts, after which the last
      exception is re-raised;
    - a WARNING is logged before each sleep.
    """
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    give_up = stop_after_attempt(embeddings.max_retries)
    return retry(
        reraise=True,
        stop=give_up,
        wait=backoff,
        retry=retry_if_exception_type(HTTPError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
|
31
|
+
|
32
|
+
|
33
|
+
def embed_with_retry(embeddings: WenxinEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call.

    Calls ``embeddings.embed(**kwargs)`` under the policy built by
    ``_create_retry_decorator`` and returns its result.
    """

    def _embed_with_retry(**call_kwargs: Any) -> Any:
        return embeddings.embed(**call_kwargs)

    with_retries = _create_retry_decorator(embeddings)(_embed_with_retry)
    return with_retries(**kwargs)
|
42
|
+
|
43
|
+
|
44
|
+
class WenxinEmbeddings(BaseModel, Embeddings):
    """Wenxin embedding models.

    To use, set the environment variables ``WENXIN_API_KEY`` and
    ``WENXIN_SECRET_KEY`` with your credentials, or pass them as named
    parameters to the constructor.

    Example:
        .. code-block:: python

            from bisheng_langchain.embeddings import WenxinEmbeddings
            wenxin_embeddings = WenxinEmbeddings(
                wenxin_api_key="my-api-key",
                wenxin_secret_key='xxx')

    """

    client: Optional[Any]  #: :meta private:
    model: str = 'embedding-v1'

    deployment: Optional[str] = 'default'
    wenxin_api_key: Optional[str] = None
    wenxin_secret_key: Optional[str] = None

    # The maximum number of tokens to embed at once (not enforced here).
    embedding_ctx_length: Optional[int] = 6144
    # Maximum number of attempts made by ``embed_with_retry``.
    max_retries: Optional[int] = 6
    # Timeout in seconds for the request.
    # NOTE(review): not currently forwarded to the Wenxin client.
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None

    # Holds any model parameters valid for `create` call not explicitly specified.
    model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values['wenxin_api_key'] = get_from_dict_or_env(
            values, 'wenxin_api_key', 'WENXIN_API_KEY')
        values['wenxin_secret_key'] = get_from_dict_or_env(
            values,
            'wenxin_secret_key',
            'WENXIN_SECRET_KEY',
        )

        api_key = values['wenxin_api_key']
        sec_key = values['wenxin_secret_key']
        try:
            from .interface import WenxinEmbeddingClient
            values['client'] = WenxinEmbeddingClient(api_key=api_key,
                                                     sec_key=sec_key)
        except AttributeError:
            raise ValueError(
                'Try upgrading it with `pip install --upgrade requests`.')
        return values

    @property
    def _invocation_params(self) -> Dict:
        wenxin_args = {
            'model': self.model,
            'request_timeout': self.request_timeout,
            **self.model_kwargs,
        }

        return wenxin_args

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Send ``texts`` to the client and return the ``embedding`` of each item.

        Raises:
            ValueError: if the client reply's ``status_code`` is not 200.
        """
        inp = {'input': texts, 'model': self.model}
        outp = self.client.create(**inp)
        if outp['status_code'] != 200:
            raise ValueError(
                f"Wenxin API returned an error: {outp['status_message']}")
        return [e['embedding'] for e in outp['data']]

    def embed_documents(self,
                        texts: List[str],
                        chunk_size: Optional[int] = 0) -> List[List[float]]:
        """Embed a list of documents with retries.

        An empty list returns ``[]`` without any network call.
        ``chunk_size`` is accepted for interface compatibility and ignored.
        """
        if not texts:
            return []
        embeddings = embed_with_retry(self, texts=texts)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call out to the Wenxin embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """

        embeddings = embed_with_retry(self, texts=[text])
        return embeddings[0]
|