bisheng-langchain 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. bisheng_langchain/__init__.py +0 -0
  2. bisheng_langchain/chains/__init__.py +5 -0
  3. bisheng_langchain/chains/combine_documents/__init__.py +0 -0
  4. bisheng_langchain/chains/combine_documents/stuff.py +56 -0
  5. bisheng_langchain/chains/question_answering/__init__.py +240 -0
  6. bisheng_langchain/chains/retrieval_qa/__init__.py +0 -0
  7. bisheng_langchain/chains/retrieval_qa/base.py +89 -0
  8. bisheng_langchain/chat_models/__init__.py +11 -0
  9. bisheng_langchain/chat_models/host_llm.py +409 -0
  10. bisheng_langchain/chat_models/interface/__init__.py +10 -0
  11. bisheng_langchain/chat_models/interface/minimax.py +123 -0
  12. bisheng_langchain/chat_models/interface/openai.py +68 -0
  13. bisheng_langchain/chat_models/interface/types.py +61 -0
  14. bisheng_langchain/chat_models/interface/utils.py +5 -0
  15. bisheng_langchain/chat_models/interface/wenxin.py +114 -0
  16. bisheng_langchain/chat_models/interface/xunfei.py +233 -0
  17. bisheng_langchain/chat_models/interface/zhipuai.py +81 -0
  18. bisheng_langchain/chat_models/minimax.py +354 -0
  19. bisheng_langchain/chat_models/proxy_llm.py +354 -0
  20. bisheng_langchain/chat_models/wenxin.py +349 -0
  21. bisheng_langchain/chat_models/xunfeiai.py +355 -0
  22. bisheng_langchain/chat_models/zhipuai.py +379 -0
  23. bisheng_langchain/document_loaders/__init__.py +3 -0
  24. bisheng_langchain/document_loaders/elem_html.py +0 -0
  25. bisheng_langchain/document_loaders/elem_image.py +0 -0
  26. bisheng_langchain/document_loaders/elem_pdf.py +655 -0
  27. bisheng_langchain/document_loaders/parsers/__init__.py +5 -0
  28. bisheng_langchain/document_loaders/parsers/image.py +28 -0
  29. bisheng_langchain/document_loaders/parsers/test_image.py +286 -0
  30. bisheng_langchain/embeddings/__init__.py +7 -0
  31. bisheng_langchain/embeddings/host_embedding.py +133 -0
  32. bisheng_langchain/embeddings/interface/__init__.py +3 -0
  33. bisheng_langchain/embeddings/interface/types.py +23 -0
  34. bisheng_langchain/embeddings/interface/wenxin.py +86 -0
  35. bisheng_langchain/embeddings/wenxin.py +139 -0
  36. bisheng_langchain/vectorstores/__init__.py +3 -0
  37. bisheng_langchain/vectorstores/elastic_keywords_search.py +284 -0
  38. bisheng_langchain-0.0.1.dist-info/METADATA +64 -0
  39. bisheng_langchain-0.0.1.dist-info/RECORD +41 -0
  40. bisheng_langchain-0.0.1.dist-info/WHEEL +5 -0
  41. bisheng_langchain-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,286 @@
1
+ # flake8: noqa
2
+ import io
3
+ import json
4
+ import logging
5
+ import os
6
+ import random
7
+ import tempfile
8
+ import time
9
+ from abc import ABC
10
+ from copy import deepcopy
11
+ from pathlib import Path
12
+ from typing import Any, Iterator, List, Mapping, Optional, Union
13
+ from urllib.parse import urlparse
14
+
15
+ import cv2
16
+ import fitz
17
+ import numpy as np
18
+ import pypdfium2
19
+ import requests
20
+ from image import LayoutParser
21
+ from langchain.document_loaders.blob_loaders import Blob
22
+
23
+
24
def norm_rect(bbox):
    """Return ``bbox`` as ``[x0, y0, x1, y1]`` with x0 <= x1 and y0 <= y1.

    Only the first four values of ``bbox`` are read; the two corner points may
    be given in any order.
    """
    corners = np.asarray(bbox[:4])
    first, second = corners[0:2], corners[2:4]
    return np.concatenate([np.minimum(first, second), np.maximum(first, second)])
30
+
31
+
32
def merge_rects(bboxes):
    """Return the smallest ``[x0, y0, x1, y1]`` box enclosing every row of ``bboxes``.

    ``bboxes`` is a 2-D numpy array whose rows are ``(x0, y0, x1, y1)``.
    """
    top_left = bboxes[:, :2].min(axis=0)
    bottom_right = bboxes[:, 2:4].max(axis=0)
    return [top_left[0], top_left[1], bottom_right[0], bottom_right[1]]
38
+
39
+
40
def get_image_blobs(pages, pdf_reader, n=None, start=0):
    """Render consecutive PDF pages to PNG ``Blob`` objects.

    Each page is first rendered with PyMuPDF (``fitz``) at 1x scale. If that
    raises, the page is rendered with pypdfium2 instead.

    Args:
        pages: an opened ``fitz`` document.
        pdf_reader: a ``pypdfium2.PdfDocument`` over the same PDF, used as the
            fallback renderer.
        n: number of pages to render. ``None`` renders from ``start`` through
            the last page of ``pages``.
        start: zero-based index of the first page to render.

    Returns:
        A list of ``Blob`` objects containing PNG bytes, one per rendered page.
    """
    if n is None:
        n = pages.page_count - start

    mat = fitz.Matrix(1, 1)  # unit scale: render at the page's native size
    blobs = []
    for pg in range(start, start + n):
        page = pages.load_page(pg)
        try:
            pm = page.get_pixmap(matrix=mat, alpha=False)
            # ``getPNGData`` is a deprecated camelCase alias that recent
            # PyMuPDF releases removed; ``tobytes(output='png')`` replaces it.
            bytes_img = pm.tobytes(output='png')
        except Exception:
            # Some PDFs cannot be rendered by fitz; use pypdfium2 instead.
            pil_image = pdf_reader.get_page(pg).render().to_pil()
            img_byte_arr = io.BytesIO()
            pil_image.save(img_byte_arr, format='PNG')
            bytes_img = img_byte_arr.getvalue()

        blobs.append(Blob(data=bytes_img))
    return blobs
59
+
60
+
61
def test():
    """Smoke test: render every page of a sample PDF and layout-parse the first page."""
    file_path = './data/达梦数据库招股说明书_test_v1.pdf'
    blob = Blob.from_path(file_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        # get_image_blobs needs an explicit page count; render the whole document.
        image_blobs = get_image_blobs(pages, pdf_reader, pages.page_count)

    assert len(image_blobs) == pages.page_count
    layout = LayoutParser()
    layout.parse(image_blobs[0])
74
+
75
+
76
def draw_polygon(image, bbox, text=None, color=(255, 0, 0), thickness=1):
    """Draw ``bbox`` onto ``image`` and optionally label it.

    Args:
        image: image array accepted by the OpenCV drawing functions.
        bbox: numpy array. Exactly 4 values are drawn as the rectangle
            ``(x0, y0, x1, y1)``; any other length is treated as a flat list of
            polygon vertices ``(x0, y0, x1, y1, ...)`` and drawn closed.
        text: optional label drawn at the first corner / vertex.
        color: outline color tuple passed to OpenCV.
        thickness: outline thickness in pixels.

    Returns:
        The annotated image.
    """
    pts = bbox.astype(np.int32)
    if pts.shape[0] == 4:
        # Pass plain Python ints: OpenCV's point parser can reject numpy
        # integer scalars depending on the build.
        start_point = (int(pts[0]), int(pts[1]))
        end_point = (int(pts[2]), int(pts[3]))
        image = cv2.rectangle(image, start_point, end_point, color, thickness)
    else:
        polys = [pts.reshape((-1, 1, 2))]
        cv2.polylines(image, polys, True, color=color, thickness=thickness)
        start_point = (int(polys[0][0, 0, 0]), int(polys[0][0, 0, 1]))

    if text:
        # Labels always use this fixed color, independent of ``color``.
        text_color = (0, 0, 255)
        image = cv2.putText(image, text, start_point, cv2.FONT_HERSHEY_SIMPLEX,
                            0.5, text_color, 1)

    return image
96
+
97
+
98
def test_vis():
    """Layout-parse pages [start, end) of a sample PDF and save visualizations.

    One PNG per page is written to ``./data``. Each PNG has the parser's
    detected regions outlined and labelled with their ``category_id``.
    """
    file_path = './data/pdf_input/《中国药典》2020年版 一部.pdf'
    output_prefix = 'zhongguoyaodian_2020_v1'
    start, end = 70, 80
    n = end - start  # derived, so the page window is defined in one place
    blob = Blob.from_path(file_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        image_blobs = get_image_blobs(pages, pdf_reader, n, start)

    assert len(image_blobs) == n

    # Build the parser once rather than once per page.
    layout = LayoutParser()
    for idx, page_blob in enumerate(image_blobs, start=start):
        out = layout.parse(page_blob)
        res = json.loads(out[0].page_content)
        bboxes = np.asarray([r['bbox'] for r in res])
        labels = [str(r['category_id']) for r in res]

        bytes_arr = np.frombuffer(page_blob.as_bytes(), dtype=np.uint8)
        image = cv2.imdecode(bytes_arr, flags=1)
        for bbox, text in zip(bboxes, labels):
            image = draw_polygon(image, bbox, text)

        outf = f'./data/{output_prefix}_layout_p{idx+1}_vis.png'
        cv2.imwrite(outf, image)
134
+
135
+
136
def order_by_tbyx(block_info, th=10):
    """Order blocks top-to-bottom, then left-to-right within a visual line.

    Blocks are first sorted by ``(y0, x0)``. A further pass then swaps any two
    adjacent blocks whose ``y0`` values differ by less than ``th`` when the
    later one starts further left.

    Args:
        block_info: sequence of tuples whose first two items are ``x0`` and
            ``y0`` (e.g. ``(x0, y0, x1, y1, text, ...)``).
        th: maximum ``y0`` difference for two blocks to count as the same line.

    Returns:
        A new list with the blocks in reading order. The input is not modified.
    """
    res = sorted(block_info, key=lambda b: (b[1], b[0]))
    for i in range(len(res) - 1):
        # Bubble res[i + 1] leftwards. The range must reach j == 0 so that the
        # first pair is compared too (range(i, 0, -1) skipped it).
        for j in range(i, -1, -1):
            cur, nxt = res[j], res[j + 1]
            if abs(nxt[1] - cur[1]) < th and nxt[0] < cur[0]:
                res[j], res[j + 1] = nxt, cur
            else:
                break
    return res
154
+
155
+
156
def test_vis2():
    """Visualize PyMuPDF text/image blocks for pages [start, end) of a sample PDF.

    For every page this draws:
      * the page rectangle inset by a 10-unit margin (label ``0.0``),
      * each block from ``extractBLOCKS`` (label ``text-i`` / ``image-i``),
      * the box enclosing all blocks (label ``0.1``).
    The result is written to ``./data`` as a PNG.
    """
    file_path = './data/pdf_input/达梦数据库招股说明书.pdf'
    output_prefix = 'dameng_pageblock'
    start, end = 0, 10
    n = end - start
    IMG_BLOCK_TYPE = 1  # block-type value (last tuple item) treated as an image
    margin = 10

    blob = Blob.from_path(file_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        image_blobs = get_image_blobs(pages, pdf_reader, n, start)

    # Only pages [start, end) were rendered, so compare with n rather than the
    # document's total page count (which made this assertion fail).
    assert len(image_blobs) == n

    for idx, page_blob in enumerate(image_blobs, start=start):
        page = pages.load_page(idx)

        rect = page.rect
        print('rect', rect)
        b0 = np.asarray([rect.x0 + margin, rect.y0 + margin,
                         rect.x1 - margin, rect.y1 - margin])

        bytes_arr = np.frombuffer(page_blob.as_bytes(), dtype=np.uint8)
        image = cv2.imdecode(bytes_arr, flags=1)
        image = draw_polygon(image, b0, '0.0')

        blocks = page.get_textpage().extractBLOCKS()
        bboxes = []
        for off, b in enumerate(blocks):
            kind = 'image' if b[-1] == IMG_BLOCK_TYPE else 'text'
            label = f'{kind}-{off}'
            print('block', b, label)
            bbox = np.asarray(b[:4])
            bboxes.append(bbox)
            image = draw_polygon(image, bbox, label)

        if bboxes:
            b1 = np.asarray(merge_rects(np.asarray(bboxes)))
            image = draw_polygon(image, b1, '0.1')

        outf = f'./data/{output_prefix}_p{idx}_vis.png'
        cv2.imwrite(outf, image)
210
+
211
+
212
def test_vis3():
    """Like ``test_vis2`` but maps each block through ``page.rotation_matrix``.

    Pages [start, end) of a sample PDF are rendered. Each block from
    ``extractBLOCKS`` has its corners transformed by the page's rotation
    matrix, is re-normalized with ``norm_rect`` and is drawn. The page inset
    box and the box enclosing all blocks are drawn too, and the result is
    written to ``./data``.
    """
    file_path = './data/pdf_input/《中国药典》2020年版 一部.pdf'
    start, end = 50, 60
    n = end - start
    output_prefix = 'zhongguoyaodian_2020_v1'
    IMG_BLOCK_TYPE = 1  # block-type value (last tuple item) treated as an image
    margin = 10

    blob = Blob.from_path(file_path)
    with blob.as_bytes_io() as pdf_stream:
        pages = fitz.open(pdf_stream)
        pdf_reader = pypdfium2.PdfDocument(pdf_stream, autoclose=True)
        # Use `start` rather than a duplicated literal so the window stays consistent.
        image_blobs = get_image_blobs(pages, pdf_reader, n, start=start)

    assert len(image_blobs) == n

    for idx, page_blob in enumerate(image_blobs, start=start):
        page = pages.load_page(idx)

        rect = page.rect
        print('rect', rect)
        b0 = np.asarray([rect.x0 + margin, rect.y0 + margin,
                         rect.x1 - margin, rect.y1 - margin])

        bytes_arr = np.frombuffer(page_blob.as_bytes(), dtype=np.uint8)
        image = cv2.imdecode(bytes_arr, flags=1)
        image = draw_polygon(image, b0, '0.0')

        # The 6 matrix values (a, b, c, d, e, f) become rows [[a, b], [c, d], [e, f]],
        # so a homogeneous row vector [x, y, 1] maps as [x, y, 1] @ M.
        rotation_matrix = np.asarray(page.rotation_matrix).reshape((3, 2))

        blocks = page.get_textpage().extractBLOCKS()
        bboxes = []
        for off, b in enumerate(blocks):
            kind = 'image' if b[-1] == IMG_BLOCK_TYPE else 'text'
            label = f'{kind}-{off}'
            print('block', b, label)

            corners = np.asarray(b[:4]).reshape((-1, 2))
            homog = np.hstack([corners, np.ones((len(corners), 1))])
            new_bbox = norm_rect(np.dot(homog, rotation_matrix).reshape(-1))

            print('new_bboxes', new_bbox)
            bboxes.append(new_bbox)
            image = draw_polygon(image, new_bbox, label)

        print(bboxes)
        if bboxes:
            b1 = np.asarray(merge_rects(np.asarray(bboxes)))
            image = draw_polygon(image, b1, '0.1')

        outf = f'./data/{output_prefix}_p{idx}_vis.png'
        cv2.imwrite(outf, image)
281
+
282
+
283
if __name__ == '__main__':
    # Manual visual checks. They read sample PDFs from ./data and write PNGs
    # there, so they must not run when this module is merely imported (for
    # example during pytest collection). Uncomment the experiment to run.
    # test_vis3()
    # test_vis2()
    test_vis()
    # test()
@@ -0,0 +1,7 @@
1
+ from .host_embedding import BGEZhEmbedding, GTEEmbedding, HostEmbeddings, ME5Embedding
2
+ from .wenxin import WenxinEmbeddings
3
+
4
# Public names re-exported by this package.
__all__ = [
    'WenxinEmbeddings',
    'ME5Embedding',
    'BGEZhEmbedding',
    'GTEEmbedding',
    'HostEmbeddings',
]
@@ -0,0 +1,133 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import requests
7
+ from langchain.embeddings.base import Embeddings
8
+ from langchain.utils import get_from_dict_or_env
9
+ from pydantic import BaseModel, Extra, Field, root_validator
10
+ from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
11
+ wait_exponential)
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def _create_retry_decorator(
        embeddings: HostEmbeddings) -> Callable[[Any], Any]:
    """Build the tenacity ``retry`` decorator used for host embedding calls.

    Any exception triggers another attempt, up to ``embeddings.max_retries``
    attempts in total. The wait between attempts grows exponentially and is
    clamped to the 4-10 second range. Each retry is logged at WARNING, and the
    final exception is re-raised once attempts run out.
    """
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    give_up = stop_after_attempt(embeddings.max_retries)
    return retry(
        stop=give_up,
        wait=backoff,
        retry=retry_if_exception_type(Exception),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
29
+
30
+
31
def embed_with_retry(embeddings: HostEmbeddings, **kwargs: Any) -> Any:
    """Call ``embeddings.embed(**kwargs)`` under the module's retry policy."""

    def _embed_with_retry(**kwargs: Any) -> Any:
        return embeddings.embed(**kwargs)

    wrapped = _create_retry_decorator(embeddings)(_embed_with_retry)
    return wrapped(**kwargs)
40
+
41
+
42
class HostEmbeddings(BaseModel, Embeddings):
    """Embeddings from a self-hosted model service.

    Each request is a JSON POST to ``{host_base_url}/{model}/infer``. The base
    URL comes from the ``host_base_url`` argument or, if that is empty, from
    the ``HostBaseUrl`` environment variable.
    """

    client: Optional[Any]  #: :meta private:
    # Model name; also used as the path segment of the inference URL.
    model: str = 'embedding-host'
    # Base URL of the embedding service (falls back to env ``HostBaseUrl``).
    host_base_url: Optional[str] = None

    deployment: Optional[str] = 'default'

    # Declared maximum number of tokens per text (not enforced by this class).
    embedding_ctx_length: Optional[int] = 6144
    # Maximum number of attempts made by the retry wrapper.
    max_retries: Optional[int] = 6
    # Timeout passed to the HTTP request; None means no timeout.
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None

    # Extra model parameters merged into ``_invocation_params``.
    model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)

    # When True, each request payload is printed before it is sent.
    verbose: Optional[bool] = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve ``host_base_url`` and install ``requests.post`` as the client."""
        values['host_base_url'] = get_from_dict_or_env(values, 'host_base_url',
                                                       'HostBaseUrl')
        values['client'] = requests.post
        return values

    @property
    def _invocation_params(self) -> Dict:
        api_args = {
            'model': self.model,
            'request_timeout': self.request_timeout,
            **self.model_kwargs,
        }
        return api_args

    def embed(self, texts: List[str], **kwargs) -> List[List[float]]:
        """POST ``texts`` to the inference endpoint and return the embeddings.

        Args:
            texts: texts to embed.
            **kwargs: ``type`` is forwarded as the payload's ``type`` field
                (default ``'raw'``; this class sends ``'doc'`` or ``'query'``).

        Raises:
            ValueError: if the response's ``status_code`` field is not 200.
        """
        emb_type = kwargs.get('type', 'raw')
        inp = {'texts': texts, 'model': self.model, 'type': emb_type}
        if self.verbose:
            print('payload', inp)

        url = f'{self.host_base_url}/{self.model}/infer'
        # Honour request_timeout (None keeps the previous no-timeout behaviour).
        outp = self.client(url=url, json=inp,
                           timeout=self.request_timeout).json()
        if outp['status_code'] != 200:
            raise ValueError(
                f"API returned an error: {outp['status_message']}")
        return outp['embeddings']

    def embed_documents(self,
                        texts: List[str],
                        chunk_size: Optional[int] = 0) -> List[List[float]]:
        """Embed search documents and return one vector per input text.

        Empty strings are not sent to the server. Each one gets an all-zero
        vector with the same length as the other results, so the output stays
        aligned with ``texts``. If every text is empty, each gets an empty
        list. ``chunk_size`` is accepted for interface compatibility and
        ignored.
        """
        if not texts:
            return []
        positions = [i for i, text in enumerate(texts) if text]
        if not positions:
            # Nothing to embed, and no way to learn the vector dimension.
            return [[] for _ in texts]

        embedded = embed_with_retry(self,
                                    texts=[texts[i] for i in positions],
                                    type='doc')
        dim = len(embedded[0]) if embedded else 0
        result = [[0.0] * dim for _ in texts]
        for i, vector in zip(positions, embedded):
            result[i] = vector
        return result

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        embeddings = embed_with_retry(self, texts=[text], type='query')
        return embeddings[0]
119
+
120
+
121
class ME5Embedding(HostEmbeddings):
    """``HostEmbeddings`` preset for the ``multi-e5`` model (context length 512)."""

    model: str = 'multi-e5'
    embedding_ctx_length: int = 512
124
+
125
+
126
class BGEZhEmbedding(HostEmbeddings):
    """``HostEmbeddings`` preset for the ``bge-zh`` model (context length 512)."""

    model: str = 'bge-zh'
    embedding_ctx_length: int = 512
129
+
130
+
131
class GTEEmbedding(HostEmbeddings):
    """``HostEmbeddings`` preset for the ``gte`` model (context length 512)."""

    model: str = 'gte'
    embedding_ctx_length: int = 512
@@ -0,0 +1,3 @@
1
+ from .wenxin import EmbeddingClient as WenxinEmbeddingClient
2
+
3
# Public names exported by this package.
__all__ = [
    'WenxinEmbeddingClient',
]
@@ -0,0 +1,23 @@
1
+ from typing import Any, Dict, List, Union
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class EmbeddingInput(BaseModel):
    """Embedding request: a model name and one text or a list of texts."""

    model: str
    input: Union[str, List[str]]
9
+
10
+
11
class Embedding(BaseModel):
    """A single embedding vector and its position within the request."""

    object: str = 'embedding'
    embedding: List[float]
    index: int
15
+
16
+
17
class EmbeddingOutput(BaseModel):
    """Embedding response envelope.

    ``status_code`` / ``status_message`` report the outcome of the call. The
    optional metadata fields default to ``None``, and ``data`` defaults to an
    empty list.
    """

    status_code: int
    status_message: str = 'success'
    # Fields defaulting to None are annotated as optional so the declared
    # types match the defaults.
    object: Union[str, None] = None
    # pydantic copies mutable field defaults per instance, so this shared
    # literal is never mutated across instances.
    data: List[Embedding] = []
    model: Union[str, None] = None
    usage: Union[Dict[str, Any], None] = None
@@ -0,0 +1,86 @@
1
+ import json
2
+
3
+ import numpy as np
4
+ import requests
5
+ from requests.exceptions import HTTPError
6
+
7
+
8
def get_access_token(api_key, sec_key, timeout=None):
    """Exchange a Baidu API key / secret key pair for an OAuth access token.

    Args:
        api_key: application API key (sent as ``client_id``).
        sec_key: application secret key (sent as ``client_secret``).
        timeout: optional ``requests`` timeout. ``None`` (the default) waits
            indefinitely, as before.

    Returns:
        The ``access_token`` string from the token endpoint's JSON response.

    Raises:
        HTTPError: if the response contains no ``access_token``.
    """
    url = 'https://aip.baidubce.com/oauth/2.0/token'
    # Passing the credentials via ``params`` lets requests URL-encode them.
    params = {
        'grant_type': 'client_credentials',
        'client_id': api_key,
        'client_secret': sec_key,
    }
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json'
    }

    response = requests.request('POST', url, headers=headers, params=params,
                                data=json.dumps(''), timeout=timeout)
    token = response.json().get('access_token')
    if not token:
        # Fail here with the server's reply instead of passing
        # ``access_token=None`` on to the embeddings endpoint.
        raise HTTPError(f'failed to get access token: {response.text}')
    return token
21
+
22
+
23
class EmbeddingClient(object):
    """Minimal client for the Wenxin (Baidu) embeddings endpoint.

    Texts are sent in batches of at most ``max_text_num``. Each text is limited
    to ``max_text_tokens`` characters; the limit is applied to ``len(text)``,
    not to model tokens.
    """

    def __init__(self, api_key, sec_key, **kwargs):
        """
        Args:
            api_key: Baidu API key used to obtain an access token.
            sec_key: Baidu secret key used to obtain an access token.
            **kwargs:
                drop_exceed_token: truncate over-long texts instead of raising
                    (default True).
                request_timeout: optional ``requests`` timeout for embedding
                    calls (default None, i.e. no timeout).
        """
        self.api_key = api_key
        self.sec_key = sec_key
        self.ep_url = ('https://aip.baidubce.com/rpc/2.0/ai_custom/v1/'
                       'wenxinworkshop/embeddings')
        self.headers = {'Content-Type': 'application/json'}
        self.max_text_tokens = 384
        self.max_text_num = 16
        self.drop_exceed_token = kwargs.get('drop_exceed_token', True)
        self.request_timeout = kwargs.get('request_timeout')

    def create(self, model, input, verbose=False, **kwargs):
        """Embed ``input`` (one string or a list of strings) with ``model``.

        ``verbose`` and extra keyword arguments are accepted and ignored.

        Returns:
            ``dict(status_code=200, model=model, data=[...], usage={...})``.
            ``data`` concatenates the per-batch ``data`` lists and ``usage``
            sums ``prompt_tokens`` / ``total_tokens`` across batches.

        Raises:
            HTTPError: if a text exceeds the length limit (when truncation is
                disabled), on a non-200 HTTP status, when the response carries
                an ``error_code``, or when the response cannot be parsed.
        """
        texts = [input] if isinstance(input, str) else input

        if self.drop_exceed_token:
            texts = [t[:self.max_text_tokens] for t in texts]

        if not all(len(text) <= self.max_text_tokens for text in texts):
            raise HTTPError(
                f'text exceed max token size {self.max_text_tokens}')

        token = get_access_token(self.api_key, self.sec_key)
        endpoint = f'{self.ep_url}/{model}?access_token={token}'

        def _call(sub_texts):
            response = requests.post(endpoint,
                                     headers=self.headers,
                                     data=json.dumps({'input': sub_texts}),
                                     timeout=self.request_timeout)
            if response.status_code != 200:
                raise HTTPError('requests error')
            try:
                info = json.loads(response.text)
                # The body may still report a failure via error_code/error_msg.
                status_code = info.get('error_code', 200)
                if status_code != 200:
                    raise HTTPError(info.get('error_msg', 'success'))
                return info['data'], info['usage']
            except HTTPError:
                raise
            except Exception as e:
                # Parse/shape problems surface as HTTPError, which the caller's
                # retry policy handles.
                raise HTTPError(str(e)) from e

        data = []
        usage = {'prompt_tokens': 0, 'total_tokens': 0}
        for i in range(0, len(texts), self.max_text_num):
            sub_data, sub_usage = _call(texts[i:i + self.max_text_num])
            data.extend(sub_data)
            usage['prompt_tokens'] += sub_usage['prompt_tokens']
            usage['total_tokens'] += sub_usage['total_tokens']

        return dict(status_code=200, model=model, data=data, usage=usage)
@@ -0,0 +1,139 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ # import warnings
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+
7
+ # import numpy as np
8
+ from langchain.embeddings.base import Embeddings
9
+ from langchain.utils import get_from_dict_or_env
10
+ from pydantic import BaseModel, Extra, Field, root_validator
11
+ from requests.exceptions import HTTPError
12
+ from tenacity import (before_sleep_log, retry, retry_if_exception_type, stop_after_attempt,
13
+ wait_exponential)
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def _create_retry_decorator(
        embeddings: WenxinEmbeddings) -> Callable[[Any], Any]:
    """Build the tenacity ``retry`` decorator used for Wenxin embedding calls.

    Only ``HTTPError`` triggers another attempt, up to
    ``embeddings.max_retries`` attempts in total. The wait between attempts
    grows exponentially and is clamped to the 4-10 second range. Each retry is
    logged at WARNING, and the final exception is re-raised.
    """
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    give_up = stop_after_attempt(embeddings.max_retries)
    return retry(
        stop=give_up,
        wait=backoff,
        retry=retry_if_exception_type(HTTPError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
31
+
32
+
33
def embed_with_retry(embeddings: WenxinEmbeddings, **kwargs: Any) -> Any:
    """Call ``embeddings.embed(**kwargs)`` under the module's retry policy."""

    def _embed_with_retry(**kwargs: Any) -> Any:
        return embeddings.embed(**kwargs)

    wrapped = _create_retry_decorator(embeddings)(_embed_with_retry)
    return wrapped(**kwargs)
42
+
43
+
44
class WenxinEmbeddings(BaseModel, Embeddings):
    """Wenxin embedding models.

    To use, set the environment variables ``WENXIN_API_KEY`` and
    ``WENXIN_SECRET_KEY``, or pass the keys as named constructor parameters.

    Example:
        .. code-block:: python
            from bisheng_langchain.embeddings import WenxinEmbeddings
            wenxin_embeddings = WenxinEmbeddings(
                wenxin_api_key="my-api-key",
                wenxin_secret_key='xxx')

    """

    client: Optional[Any]  #: :meta private:
    # Model name appended to the Wenxin embeddings endpoint URL.
    model: str = 'embedding-v1'

    deployment: Optional[str] = 'default'
    # Credentials; fall back to WENXIN_API_KEY / WENXIN_SECRET_KEY.
    wenxin_api_key: Optional[str] = None
    wenxin_secret_key: Optional[str] = None

    # Declared maximum number of tokens per text (not enforced by this class).
    embedding_ctx_length: Optional[int] = 6144
    # Maximum number of attempts made by the retry wrapper.
    max_retries: Optional[int] = 6
    # Timeout forwarded to the embedding client's HTTP calls.
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None

    # Extra model parameters merged into ``_invocation_params``.
    model_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API credentials and build the Wenxin embedding client."""
        values['wenxin_api_key'] = get_from_dict_or_env(
            values, 'wenxin_api_key', 'WENXIN_API_KEY')
        values['wenxin_secret_key'] = get_from_dict_or_env(
            values,
            'wenxin_secret_key',
            'WENXIN_SECRET_KEY',
        )

        from .interface import WenxinEmbeddingClient

        # The client's constructor accepts **kwargs, so forwarding the timeout
        # is safe even for client versions that ignore it.
        values['client'] = WenxinEmbeddingClient(
            api_key=values['wenxin_api_key'],
            sec_key=values['wenxin_secret_key'],
            request_timeout=values.get('request_timeout'))
        return values

    @property
    def _invocation_params(self) -> Dict:
        wenxin_args = {
            'model': self.model,
            'request_timeout': self.request_timeout,
            **self.model_kwargs,
        }

        return wenxin_args

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with one client call and return the vectors in order.

        Raises:
            ValueError: if the client reports a ``status_code`` other than 200.
        """
        inp = {'input': texts, 'model': self.model}
        outp = self.client.create(**inp)
        if outp['status_code'] != 200:
            raise ValueError(
                f"Wenxin API returned an error: {outp['status_message']}")
        return [e['embedding'] for e in outp['data']]

    def embed_documents(self,
                        texts: List[str],
                        chunk_size: Optional[int] = 0) -> List[List[float]]:
        """Embed search documents, returning one vector per text.

        ``chunk_size`` is accepted for interface compatibility and ignored;
        batching is done by the client.
        """
        if not texts:
            # Skip the access-token request when there is nothing to embed.
            return []
        embeddings = embed_with_retry(self, texts=texts)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call the Wenxin embedding endpoint to embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """

        embeddings = embed_with_retry(self, texts=[text])
        return embeddings[0]
@@ -0,0 +1,3 @@
1
+ from .elastic_keywords_search import ElasticKeywordsSearch
2
+
3
# Public names exported by this package.
__all__ = [
    'ElasticKeywordsSearch',
]