langroid 0.28.7__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/.chainlit/config.toml +121 -0
- langroid/agent/.chainlit/translations/bn.json +231 -0
- langroid/agent/.chainlit/translations/en-US.json +229 -0
- langroid/agent/.chainlit/translations/gu.json +231 -0
- langroid/agent/.chainlit/translations/he-IL.json +231 -0
- langroid/agent/.chainlit/translations/hi.json +231 -0
- langroid/agent/.chainlit/translations/kn.json +231 -0
- langroid/agent/.chainlit/translations/ml.json +231 -0
- langroid/agent/.chainlit/translations/mr.json +231 -0
- langroid/agent/.chainlit/translations/ta.json +231 -0
- langroid/agent/.chainlit/translations/te.json +231 -0
- langroid/agent/.chainlit/translations/zh-CN.json +229 -0
- langroid/embedding_models/__init__.py +6 -2
- langroid/embedding_models/base.py +4 -0
- langroid/embedding_models/models.py +151 -39
- langroid/language_models/azure_openai.py +60 -28
- langroid/language_models/openai_gpt.py +26 -19
- langroid/vector_store/chromadb.py +8 -0
- {langroid-0.28.7.dist-info → langroid-0.30.0.dist-info}/METADATA +1 -1
- {langroid-0.28.7.dist-info → langroid-0.30.0.dist-info}/RECORD +23 -12
- pyproject.toml +1 -1
- langroid/agent/team.py +0 -41
- {langroid-0.28.7.dist-info → langroid-0.30.0.dist-info}/LICENSE +0 -0
- {langroid-0.28.7.dist-info → langroid-0.30.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,229 @@
|
|
1
|
+
{
|
2
|
+
"components": {
|
3
|
+
"atoms": {
|
4
|
+
"buttons": {
|
5
|
+
"userButton": {
|
6
|
+
"menu": {
|
7
|
+
"settings": "\u8bbe\u7f6e",
|
8
|
+
"settingsKey": "S",
|
9
|
+
"APIKeys": "API \u5bc6\u94a5",
|
10
|
+
"logout": "\u767b\u51fa"
|
11
|
+
}
|
12
|
+
}
|
13
|
+
}
|
14
|
+
},
|
15
|
+
"molecules": {
|
16
|
+
"newChatButton": {
|
17
|
+
"newChat": "\u65b0\u5efa\u5bf9\u8bdd"
|
18
|
+
},
|
19
|
+
"tasklist": {
|
20
|
+
"TaskList": {
|
21
|
+
"title": "\ud83d\uddd2\ufe0f \u4efb\u52a1\u5217\u8868",
|
22
|
+
"loading": "\u52a0\u8f7d\u4e2d...",
|
23
|
+
"error": "\u53d1\u751f\u9519\u8bef"
|
24
|
+
}
|
25
|
+
},
|
26
|
+
"attachments": {
|
27
|
+
"cancelUpload": "\u53d6\u6d88\u4e0a\u4f20",
|
28
|
+
"removeAttachment": "\u79fb\u9664\u9644\u4ef6"
|
29
|
+
},
|
30
|
+
"newChatDialog": {
|
31
|
+
"createNewChat": "\u521b\u5efa\u65b0\u5bf9\u8bdd\uff1f",
|
32
|
+
"clearChat": "\u8fd9\u5c06\u6e05\u9664\u5f53\u524d\u6d88\u606f\u5e76\u5f00\u59cb\u65b0\u7684\u5bf9\u8bdd\u3002",
|
33
|
+
"cancel": "\u53d6\u6d88",
|
34
|
+
"confirm": "\u786e\u8ba4"
|
35
|
+
},
|
36
|
+
"settingsModal": {
|
37
|
+
"settings": "\u8bbe\u7f6e",
|
38
|
+
"expandMessages": "\u5c55\u5f00\u6d88\u606f",
|
39
|
+
"hideChainOfThought": "\u9690\u85cf\u601d\u8003\u94fe",
|
40
|
+
"darkMode": "\u6697\u8272\u6a21\u5f0f"
|
41
|
+
},
|
42
|
+
"detailsButton": {
|
43
|
+
"using": "\u4f7f\u7528",
|
44
|
+
"used": "\u5df2\u7528"
|
45
|
+
},
|
46
|
+
"auth": {
|
47
|
+
"authLogin": {
|
48
|
+
"title": "\u767b\u5f55\u4ee5\u8bbf\u95ee\u5e94\u7528\u3002",
|
49
|
+
"form": {
|
50
|
+
"email": "\u7535\u5b50\u90ae\u7bb1\u5730\u5740",
|
51
|
+
"password": "\u5bc6\u7801",
|
52
|
+
"noAccount": "\u6ca1\u6709\u8d26\u6237\uff1f",
|
53
|
+
"alreadyHaveAccount": "\u5df2\u6709\u8d26\u6237\uff1f",
|
54
|
+
"signup": "\u6ce8\u518c",
|
55
|
+
"signin": "\u767b\u5f55",
|
56
|
+
"or": "\u6216\u8005",
|
57
|
+
"continue": "\u7ee7\u7eed",
|
58
|
+
"forgotPassword": "\u5fd8\u8bb0\u5bc6\u7801\uff1f",
|
59
|
+
"passwordMustContain": "\u60a8\u7684\u5bc6\u7801\u5fc5\u987b\u5305\u542b\uff1a",
|
60
|
+
"emailRequired": "\u7535\u5b50\u90ae\u7bb1\u662f\u5fc5\u586b\u9879",
|
61
|
+
"passwordRequired": "\u5bc6\u7801\u662f\u5fc5\u586b\u9879"
|
62
|
+
},
|
63
|
+
"error": {
|
64
|
+
"default": "\u65e0\u6cd5\u767b\u5f55\u3002",
|
65
|
+
"signin": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
|
66
|
+
"oauthsignin": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
|
67
|
+
"redirect_uri_mismatch": "\u91cd\u5b9a\u5411URI\u4e0eOAuth\u5e94\u7528\u914d\u7f6e\u4e0d\u5339\u914d\u3002",
|
68
|
+
"oauthcallbackerror": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
|
69
|
+
"oauthcreateaccount": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
|
70
|
+
"emailcreateaccount": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
|
71
|
+
"callback": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
|
72
|
+
"oauthaccountnotlinked": "\u4e3a\u4e86\u9a8c\u8bc1\u60a8\u7684\u8eab\u4efd\uff0c\u8bf7\u4f7f\u7528\u6700\u521d\u4f7f\u7528\u7684\u540c\u4e00\u8d26\u6237\u767b\u5f55\u3002",
|
73
|
+
"emailsignin": "\u65e0\u6cd5\u53d1\u9001\u90ae\u4ef6\u3002",
|
74
|
+
"emailverify": "\u8bf7\u9a8c\u8bc1\u60a8\u7684\u7535\u5b50\u90ae\u4ef6\uff0c\u5df2\u53d1\u9001\u4e00\u5c01\u65b0\u90ae\u4ef6\u3002",
|
75
|
+
"credentialssignin": "\u767b\u5f55\u5931\u8d25\u3002\u8bf7\u68c0\u67e5\u60a8\u63d0\u4f9b\u7684\u8be6\u7ec6\u4fe1\u606f\u662f\u5426\u6b63\u786e\u3002",
|
76
|
+
"sessionrequired": "\u8bf7\u767b\u5f55\u4ee5\u8bbf\u95ee\u6b64\u9875\u9762\u3002"
|
77
|
+
}
|
78
|
+
},
|
79
|
+
"authVerifyEmail": {
|
80
|
+
"almostThere": "\u60a8\u5feb\u6210\u529f\u4e86\uff01\u6211\u4eec\u5df2\u5411 ",
|
81
|
+
"verifyEmailLink": "\u8bf7\u5355\u51fb\u8be5\u90ae\u4ef6\u4e2d\u7684\u94fe\u63a5\u4ee5\u5b8c\u6210\u6ce8\u518c\u3002",
|
82
|
+
"didNotReceive": "\u6ca1\u627e\u5230\u90ae\u4ef6\uff1f",
|
83
|
+
"resendEmail": "\u91cd\u65b0\u53d1\u9001\u90ae\u4ef6",
|
84
|
+
"goBack": "\u8fd4\u56de",
|
85
|
+
"emailSent": "\u90ae\u4ef6\u5df2\u6210\u529f\u53d1\u9001\u3002",
|
86
|
+
"verifyEmail": "\u9a8c\u8bc1\u60a8\u7684\u7535\u5b50\u90ae\u4ef6\u5730\u5740"
|
87
|
+
},
|
88
|
+
"providerButton": {
|
89
|
+
"continue": "\u4f7f\u7528{{provider}}\u7ee7\u7eed",
|
90
|
+
"signup": "\u4f7f\u7528{{provider}}\u6ce8\u518c"
|
91
|
+
},
|
92
|
+
"authResetPassword": {
|
93
|
+
"newPasswordRequired": "\u65b0\u5bc6\u7801\u662f\u5fc5\u586b\u9879",
|
94
|
+
"passwordsMustMatch": "\u5bc6\u7801\u5fc5\u987b\u4e00\u81f4",
|
95
|
+
"confirmPasswordRequired": "\u786e\u8ba4\u5bc6\u7801\u662f\u5fc5\u586b\u9879",
|
96
|
+
"newPassword": "\u65b0\u5bc6\u7801",
|
97
|
+
"confirmPassword": "\u786e\u8ba4\u5bc6\u7801",
|
98
|
+
"resetPassword": "\u91cd\u7f6e\u5bc6\u7801"
|
99
|
+
},
|
100
|
+
"authForgotPassword": {
|
101
|
+
"email": "\u7535\u5b50\u90ae\u7bb1\u5730\u5740",
|
102
|
+
"emailRequired": "\u7535\u5b50\u90ae\u7bb1\u662f\u5fc5\u586b\u9879",
|
103
|
+
"emailSent": "\u8bf7\u68c0\u67e5\u7535\u5b50\u90ae\u7bb1{{email}}\u4ee5\u83b7\u53d6\u91cd\u7f6e\u5bc6\u7801\u7684\u6307\u793a\u3002",
|
104
|
+
"enterEmail": "\u8bf7\u8f93\u5165\u60a8\u7684\u7535\u5b50\u90ae\u7bb1\u5730\u5740\uff0c\u6211\u4eec\u5c06\u53d1\u9001\u91cd\u7f6e\u5bc6\u7801\u7684\u6307\u793a\u3002",
|
105
|
+
"resendEmail": "\u91cd\u65b0\u53d1\u9001\u90ae\u4ef6",
|
106
|
+
"continue": "\u7ee7\u7eed",
|
107
|
+
"goBack": "\u8fd4\u56de"
|
108
|
+
}
|
109
|
+
}
|
110
|
+
},
|
111
|
+
"organisms": {
|
112
|
+
"chat": {
|
113
|
+
"history": {
|
114
|
+
"index": {
|
115
|
+
"showHistory": "\u663e\u793a\u5386\u53f2",
|
116
|
+
"lastInputs": "\u6700\u540e\u8f93\u5165",
|
117
|
+
"noInputs": "\u5982\u6b64\u7a7a\u65f7...",
|
118
|
+
"loading": "\u52a0\u8f7d\u4e2d..."
|
119
|
+
}
|
120
|
+
},
|
121
|
+
"inputBox": {
|
122
|
+
"input": {
|
123
|
+
"placeholder": "\u5728\u8fd9\u91cc\u8f93\u5165\u60a8\u7684\u6d88\u606f..."
|
124
|
+
},
|
125
|
+
"speechButton": {
|
126
|
+
"start": "\u5f00\u59cb\u5f55\u97f3",
|
127
|
+
"stop": "\u505c\u6b62\u5f55\u97f3"
|
128
|
+
},
|
129
|
+
"SubmitButton": {
|
130
|
+
"sendMessage": "\u53d1\u9001\u6d88\u606f",
|
131
|
+
"stopTask": "\u505c\u6b62\u4efb\u52a1"
|
132
|
+
},
|
133
|
+
"UploadButton": {
|
134
|
+
"attachFiles": "\u9644\u52a0\u6587\u4ef6"
|
135
|
+
},
|
136
|
+
"waterMark": {
|
137
|
+
"text": "\u4f7f\u7528"
|
138
|
+
}
|
139
|
+
},
|
140
|
+
"Messages": {
|
141
|
+
"index": {
|
142
|
+
"running": "\u8fd0\u884c\u4e2d",
|
143
|
+
"executedSuccessfully": "\u6267\u884c\u6210\u529f",
|
144
|
+
"failed": "\u5931\u8d25",
|
145
|
+
"feedbackUpdated": "\u53cd\u9988\u66f4\u65b0",
|
146
|
+
"updating": "\u6b63\u5728\u66f4\u65b0"
|
147
|
+
}
|
148
|
+
},
|
149
|
+
"dropScreen": {
|
150
|
+
"dropYourFilesHere": "\u5728\u8fd9\u91cc\u62d6\u653e\u60a8\u7684\u6587\u4ef6"
|
151
|
+
},
|
152
|
+
"index": {
|
153
|
+
"failedToUpload": "\u4e0a\u4f20\u5931\u8d25",
|
154
|
+
"cancelledUploadOf": "\u53d6\u6d88\u4e0a\u4f20",
|
155
|
+
"couldNotReachServer": "\u65e0\u6cd5\u8fde\u63a5\u5230\u670d\u52a1\u5668",
|
156
|
+
"continuingChat": "\u7ee7\u7eed\u4e4b\u524d\u7684\u5bf9\u8bdd"
|
157
|
+
},
|
158
|
+
"settings": {
|
159
|
+
"settingsPanel": "\u8bbe\u7f6e\u9762\u677f",
|
160
|
+
"reset": "\u91cd\u7f6e",
|
161
|
+
"cancel": "\u53d6\u6d88",
|
162
|
+
"confirm": "\u786e\u8ba4"
|
163
|
+
}
|
164
|
+
},
|
165
|
+
"threadHistory": {
|
166
|
+
"sidebar": {
|
167
|
+
"filters": {
|
168
|
+
"FeedbackSelect": {
|
169
|
+
"feedbackAll": "\u53cd\u9988\uff1a\u5168\u90e8",
|
170
|
+
"feedbackPositive": "\u53cd\u9988\uff1a\u6b63\u9762",
|
171
|
+
"feedbackNegative": "\u53cd\u9988\uff1a\u8d1f\u9762"
|
172
|
+
},
|
173
|
+
"SearchBar": {
|
174
|
+
"search": "\u641c\u7d22"
|
175
|
+
}
|
176
|
+
},
|
177
|
+
"DeleteThreadButton": {
|
178
|
+
"confirmMessage": "\u8fd9\u5c06\u5220\u9664\u7ebf\u7a0b\u53ca\u5176\u6d88\u606f\u548c\u5143\u7d20\u3002",
|
179
|
+
"cancel": "\u53d6\u6d88",
|
180
|
+
"confirm": "\u786e\u8ba4",
|
181
|
+
"deletingChat": "\u5220\u9664\u5bf9\u8bdd",
|
182
|
+
"chatDeleted": "\u5bf9\u8bdd\u5df2\u5220\u9664"
|
183
|
+
},
|
184
|
+
"index": {
|
185
|
+
"pastChats": "\u8fc7\u5f80\u5bf9\u8bdd"
|
186
|
+
},
|
187
|
+
"ThreadList": {
|
188
|
+
"empty": "\u7a7a\u7684...",
|
189
|
+
"today": "\u4eca\u5929",
|
190
|
+
"yesterday": "\u6628\u5929",
|
191
|
+
"previous7days": "\u524d7\u5929",
|
192
|
+
"previous30days": "\u524d30\u5929"
|
193
|
+
},
|
194
|
+
"TriggerButton": {
|
195
|
+
"closeSidebar": "\u5173\u95ed\u4fa7\u8fb9\u680f",
|
196
|
+
"openSidebar": "\u6253\u5f00\u4fa7\u8fb9\u680f"
|
197
|
+
}
|
198
|
+
},
|
199
|
+
"Thread": {
|
200
|
+
"backToChat": "\u8fd4\u56de\u5bf9\u8bdd",
|
201
|
+
"chatCreatedOn": "\u6b64\u5bf9\u8bdd\u521b\u5efa\u4e8e"
|
202
|
+
}
|
203
|
+
},
|
204
|
+
"header": {
|
205
|
+
"chat": "\u5bf9\u8bdd",
|
206
|
+
"readme": "\u8bf4\u660e"
|
207
|
+
}
|
208
|
+
}
|
209
|
+
},
|
210
|
+
"hooks": {
|
211
|
+
"useLLMProviders": {
|
212
|
+
"failedToFetchProviders": "\u83b7\u53d6\u63d0\u4f9b\u8005\u5931\u8d25:"
|
213
|
+
}
|
214
|
+
},
|
215
|
+
"pages": {
|
216
|
+
"Design": {},
|
217
|
+
"Env": {
|
218
|
+
"savedSuccessfully": "\u4fdd\u5b58\u6210\u529f",
|
219
|
+
"requiredApiKeys": "\u5fc5\u9700\u7684API\u5bc6\u94a5",
|
220
|
+
"requiredApiKeysInfo": "\u8981\u4f7f\u7528\u6b64\u5e94\u7528\uff0c\u9700\u8981\u4ee5\u4e0bAPI\u5bc6\u94a5\u3002\u8fd9\u4e9b\u5bc6\u94a5\u5b58\u50a8\u5728\u60a8\u7684\u8bbe\u5907\u672c\u5730\u5b58\u50a8\u4e2d\u3002"
|
221
|
+
},
|
222
|
+
"Page": {
|
223
|
+
"notPartOfProject": "\u60a8\u4e0d\u662f\u6b64\u9879\u76ee\u7684\u4e00\u90e8\u5206\u3002"
|
224
|
+
},
|
225
|
+
"ResumeButton": {
|
226
|
+
"resumeChat": "\u6062\u590d\u5bf9\u8bdd"
|
227
|
+
}
|
228
|
+
}
|
229
|
+
}
|
@@ -9,8 +9,10 @@ from .base import (
|
|
9
9
|
from .models import (
|
10
10
|
OpenAIEmbeddings,
|
11
11
|
OpenAIEmbeddingsConfig,
|
12
|
-
SentenceTransformerEmbeddingsConfig,
|
13
12
|
SentenceTransformerEmbeddings,
|
13
|
+
SentenceTransformerEmbeddingsConfig,
|
14
|
+
LlamaCppServerEmbeddings,
|
15
|
+
LlamaCppServerEmbeddingsConfig,
|
14
16
|
embedding_model,
|
15
17
|
)
|
16
18
|
from .remote_embeds import (
|
@@ -27,8 +29,10 @@ __all__ = [
|
|
27
29
|
"EmbeddingModelsConfig",
|
28
30
|
"OpenAIEmbeddings",
|
29
31
|
"OpenAIEmbeddingsConfig",
|
30
|
-
"SentenceTransformerEmbeddingsConfig",
|
31
32
|
"SentenceTransformerEmbeddings",
|
33
|
+
"SentenceTransformerEmbeddingsConfig",
|
34
|
+
"LlamaCppServerEmbeddings",
|
35
|
+
"LlamaCppServerEmbeddingsConfig",
|
32
36
|
"embedding_model",
|
33
37
|
"RemoteEmbeddingsConfig",
|
34
38
|
"RemoteEmbeddings",
|
@@ -26,6 +26,8 @@ class EmbeddingModel(ABC):
|
|
26
26
|
from langroid.embedding_models.models import (
|
27
27
|
FastEmbedEmbeddings,
|
28
28
|
FastEmbedEmbeddingsConfig,
|
29
|
+
LlamaCppServerEmbeddings,
|
30
|
+
LlamaCppServerEmbeddingsConfig,
|
29
31
|
OpenAIEmbeddings,
|
30
32
|
OpenAIEmbeddingsConfig,
|
31
33
|
SentenceTransformerEmbeddings,
|
@@ -44,6 +46,8 @@ class EmbeddingModel(ABC):
|
|
44
46
|
return SentenceTransformerEmbeddings(config)
|
45
47
|
elif isinstance(config, FastEmbedEmbeddingsConfig):
|
46
48
|
return FastEmbedEmbeddings(config)
|
49
|
+
elif isinstance(config, LlamaCppServerEmbeddingsConfig):
|
50
|
+
return LlamaCppServerEmbeddings(config)
|
47
51
|
else:
|
48
52
|
raise ValueError(f"Unknown embedding config: {config.__repr_name__}")
|
49
53
|
|
@@ -3,6 +3,7 @@ import os
|
|
3
3
|
from functools import cached_property
|
4
4
|
from typing import Any, Callable, Dict, List, Optional
|
5
5
|
|
6
|
+
import requests
|
6
7
|
import tiktoken
|
7
8
|
from dotenv import load_dotenv
|
8
9
|
from openai import OpenAI
|
@@ -48,13 +49,19 @@ class FastEmbedEmbeddingsConfig(EmbeddingModelsConfig):
|
|
48
49
|
additional_kwargs: Dict[str, Any] = {}
|
49
50
|
|
50
51
|
|
52
|
+
class LlamaCppServerEmbeddingsConfig(EmbeddingModelsConfig):
|
53
|
+
api_base: str = ""
|
54
|
+
context_length: int = 2048
|
55
|
+
batch_size: int = 2048
|
56
|
+
|
57
|
+
|
51
58
|
class EmbeddingFunctionCallable:
|
52
59
|
"""
|
53
60
|
A callable class designed to generate embeddings for a list of texts using
|
54
61
|
the OpenAI API, with automatic retries on failure.
|
55
62
|
|
56
63
|
Attributes:
|
57
|
-
|
64
|
+
embed_model (EmbeddingModel): An instance of EmbeddingModel that provides
|
58
65
|
configuration and utilities for generating embeddings.
|
59
66
|
|
60
67
|
Methods:
|
@@ -62,7 +69,7 @@ class EmbeddingFunctionCallable:
|
|
62
69
|
a list of input texts.
|
63
70
|
"""
|
64
71
|
|
65
|
-
def __init__(self,
|
72
|
+
def __init__(self, embed_model: EmbeddingModel, batch_size: int = 512):
|
66
73
|
"""
|
67
74
|
Initialize the EmbeddingFunctionCallable with a specific model.
|
68
75
|
|
@@ -71,7 +78,7 @@ class EmbeddingFunctionCallable:
|
|
71
78
|
generating embeddings.
|
72
79
|
batch_size (int): Batch size
|
73
80
|
"""
|
74
|
-
self.
|
81
|
+
self.embed_model = embed_model
|
75
82
|
self.batch_size = batch_size
|
76
83
|
|
77
84
|
def __call__(self, input: List[str]) -> Embeddings:
|
@@ -91,14 +98,46 @@ class EmbeddingFunctionCallable:
|
|
91
98
|
Returns:
|
92
99
|
Embeddings: A list of embedding vectors corresponding to the input texts.
|
93
100
|
"""
|
94
|
-
tokenized_texts = self.model.truncate_texts(input)
|
95
101
|
embeds = []
|
96
|
-
|
97
|
-
|
98
|
-
|
102
|
+
|
103
|
+
if isinstance(self.embed_model, OpenAIEmbeddings):
|
104
|
+
tokenized_texts = self.embed_model.truncate_texts(input)
|
105
|
+
|
106
|
+
for batch in batched(tokenized_texts, self.batch_size):
|
107
|
+
result = self.embed_model.client.embeddings.create(
|
108
|
+
input=batch, model=self.embed_model.config.model_name
|
109
|
+
)
|
110
|
+
batch_embeds = [d.embedding for d in result.data]
|
111
|
+
embeds.extend(batch_embeds)
|
112
|
+
|
113
|
+
elif isinstance(self.embed_model, SentenceTransformerEmbeddings):
|
114
|
+
if self.embed_model.config.data_parallel:
|
115
|
+
embeds = self.embed_model.model.encode_multi_process(
|
116
|
+
input,
|
117
|
+
self.embed_model.pool,
|
118
|
+
batch_size=self.batch_size,
|
119
|
+
).tolist()
|
120
|
+
else:
|
121
|
+
for str_batch in batched(input, self.batch_size):
|
122
|
+
batch_embeds = self.embed_model.model.encode(
|
123
|
+
str_batch, convert_to_numpy=True
|
124
|
+
).tolist() # type: ignore
|
125
|
+
embeds.extend(batch_embeds)
|
126
|
+
|
127
|
+
elif isinstance(self.embed_model, FastEmbedEmbeddings):
|
128
|
+
embeddings = self.embed_model.model.embed(
|
129
|
+
input, batch_size=self.batch_size, parallel=self.embed_model.parallel
|
99
130
|
)
|
100
|
-
|
101
|
-
embeds.
|
131
|
+
|
132
|
+
embeds = [embedding.tolist() for embedding in embeddings]
|
133
|
+
elif isinstance(self.embed_model, LlamaCppServerEmbeddings):
|
134
|
+
for input_string in input:
|
135
|
+
tokenized_text = self.embed_model.tokenize_string(input_string)
|
136
|
+
for token_batch in batched(tokenized_text, self.batch_size):
|
137
|
+
gen_embedding = self.embed_model.generate_embedding(
|
138
|
+
self.embed_model.detokenize_string(list(token_batch))
|
139
|
+
)
|
140
|
+
embeds.append(gen_embedding)
|
102
141
|
return embeds
|
103
142
|
|
104
143
|
|
@@ -176,24 +215,7 @@ class SentenceTransformerEmbeddings(EmbeddingModel):
|
|
176
215
|
self.config.context_length = self.tokenizer.model_max_length
|
177
216
|
|
178
217
|
def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
|
179
|
-
|
180
|
-
if self.config.data_parallel:
|
181
|
-
embeds: Embeddings = self.model.encode_multi_process(
|
182
|
-
texts,
|
183
|
-
self.pool,
|
184
|
-
batch_size=self.config.batch_size,
|
185
|
-
).tolist()
|
186
|
-
else:
|
187
|
-
embeds = []
|
188
|
-
for batch in batched(texts, self.config.batch_size):
|
189
|
-
batch_embeds = self.model.encode(
|
190
|
-
batch, convert_to_numpy=True
|
191
|
-
).tolist() # type: ignore
|
192
|
-
embeds.extend(batch_embeds)
|
193
|
-
|
194
|
-
return embeds
|
195
|
-
|
196
|
-
return fn
|
218
|
+
return EmbeddingFunctionCallable(self, self.config.batch_size)
|
197
219
|
|
198
220
|
@property
|
199
221
|
def embedding_dims(self) -> int:
|
@@ -214,10 +236,10 @@ class FastEmbedEmbeddings(EmbeddingModel):
|
|
214
236
|
|
215
237
|
super().__init__()
|
216
238
|
self.config = config
|
217
|
-
self.
|
218
|
-
self.
|
239
|
+
self.batch_size = config.batch_size
|
240
|
+
self.parallel = config.parallel
|
219
241
|
|
220
|
-
self.
|
242
|
+
self.model = TextEmbedding(
|
221
243
|
model_name=self.config.model_name,
|
222
244
|
cache_dir=self.config.cache_dir,
|
223
245
|
threads=self.config.threads,
|
@@ -225,14 +247,7 @@ class FastEmbedEmbeddings(EmbeddingModel):
|
|
225
247
|
)
|
226
248
|
|
227
249
|
def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
|
228
|
-
|
229
|
-
embeddings = self._model.embed(
|
230
|
-
texts, batch_size=self._batch_size, parallel=self._parallel
|
231
|
-
)
|
232
|
-
|
233
|
-
return [embedding.tolist() for embedding in embeddings]
|
234
|
-
|
235
|
-
return fn
|
250
|
+
return EmbeddingFunctionCallable(self, self.config.batch_size)
|
236
251
|
|
237
252
|
@cached_property
|
238
253
|
def embedding_dims(self) -> int:
|
@@ -240,10 +255,105 @@ class FastEmbedEmbeddings(EmbeddingModel):
|
|
240
255
|
return len(embed_func(["text"])[0])
|
241
256
|
|
242
257
|
|
258
|
+
LCSEC = LlamaCppServerEmbeddingsConfig
|
259
|
+
|
260
|
+
|
261
|
+
class LlamaCppServerEmbeddings(EmbeddingModel):
|
262
|
+
def __init__(self, config: LCSEC = LCSEC()):
|
263
|
+
super().__init__()
|
264
|
+
self.config = config
|
265
|
+
|
266
|
+
if self.config.api_base == "":
|
267
|
+
raise ValueError(
|
268
|
+
"""Api Base MUST be set for Llama Server Embeddings.
|
269
|
+
"""
|
270
|
+
)
|
271
|
+
|
272
|
+
self.tokenize_url = self.config.api_base + "/tokenize"
|
273
|
+
self.detokenize_url = self.config.api_base + "/detokenize"
|
274
|
+
self.embedding_url = self.config.api_base + "/embeddings"
|
275
|
+
|
276
|
+
def tokenize_string(self, text: str) -> List[int]:
|
277
|
+
data = {"content": text, "add_special": False, "with_pieces": False}
|
278
|
+
response = requests.post(self.tokenize_url, json=data)
|
279
|
+
|
280
|
+
if response.status_code == 200:
|
281
|
+
tokens = response.json()["tokens"]
|
282
|
+
if not (isinstance(tokens, list) and isinstance(tokens[0], (int, float))):
|
283
|
+
# not all(isinstance(token, (int, float)) for token in tokens):
|
284
|
+
raise ValueError(
|
285
|
+
"""Tokenizer endpoint has not returned the correct format.
|
286
|
+
Is the URL correct?
|
287
|
+
"""
|
288
|
+
)
|
289
|
+
return tokens
|
290
|
+
else:
|
291
|
+
raise requests.HTTPError(
|
292
|
+
self.tokenize_url,
|
293
|
+
response.status_code,
|
294
|
+
"Failed to connect to tokenization provider",
|
295
|
+
)
|
296
|
+
|
297
|
+
def detokenize_string(self, tokens: List[int]) -> str:
|
298
|
+
data = {"tokens": tokens}
|
299
|
+
response = requests.post(self.detokenize_url, json=data)
|
300
|
+
|
301
|
+
if response.status_code == 200:
|
302
|
+
text = response.json()["content"]
|
303
|
+
if not isinstance(text, str):
|
304
|
+
raise ValueError(
|
305
|
+
"""Deokenizer endpoint has not returned the correct format.
|
306
|
+
Is the URL correct?
|
307
|
+
"""
|
308
|
+
)
|
309
|
+
return text
|
310
|
+
else:
|
311
|
+
raise requests.HTTPError(
|
312
|
+
self.detokenize_url,
|
313
|
+
response.status_code,
|
314
|
+
"Failed to connect to detokenization provider",
|
315
|
+
)
|
316
|
+
|
317
|
+
def truncate_string_to_context_size(self, text: str) -> str:
|
318
|
+
tokens = self.tokenize_string(text)
|
319
|
+
tokens = tokens[: self.config.context_length]
|
320
|
+
return self.detokenize_string(tokens)
|
321
|
+
|
322
|
+
def generate_embedding(self, text: str) -> List[int | float]:
|
323
|
+
data = {"content": text}
|
324
|
+
response = requests.post(self.embedding_url, json=data)
|
325
|
+
|
326
|
+
if response.status_code == 200:
|
327
|
+
embeddings = response.json()["embedding"]
|
328
|
+
if not (
|
329
|
+
isinstance(embeddings, list) and isinstance(embeddings[0], (int, float))
|
330
|
+
):
|
331
|
+
raise ValueError(
|
332
|
+
"""Embedding endpoint has not returned the correct format.
|
333
|
+
Is the URL correct?
|
334
|
+
"""
|
335
|
+
)
|
336
|
+
return embeddings
|
337
|
+
else:
|
338
|
+
raise requests.HTTPError(
|
339
|
+
self.embedding_url,
|
340
|
+
response.status_code,
|
341
|
+
"Failed to connect to embedding provider",
|
342
|
+
)
|
343
|
+
|
344
|
+
def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
|
345
|
+
return EmbeddingFunctionCallable(self, self.config.batch_size)
|
346
|
+
|
347
|
+
@property
|
348
|
+
def embedding_dims(self) -> int:
|
349
|
+
return self.config.dims
|
350
|
+
|
351
|
+
|
243
352
|
def embedding_model(embedding_fn_type: str = "openai") -> EmbeddingModel:
|
244
353
|
"""
|
245
354
|
Args:
|
246
|
-
embedding_fn_type: "openai" or "
|
355
|
+
embedding_fn_type: "openai" or "fastembed" or
|
356
|
+
"llamacppserver" or "sentencetransformer" # others soon
|
247
357
|
Returns:
|
248
358
|
EmbeddingModel
|
249
359
|
"""
|
@@ -251,5 +361,7 @@ def embedding_model(embedding_fn_type: str = "openai") -> EmbeddingModel:
|
|
251
361
|
return OpenAIEmbeddings # type: ignore
|
252
362
|
elif embedding_fn_type == "fastembed":
|
253
363
|
return FastEmbedEmbeddings # type: ignore
|
364
|
+
elif embedding_fn_type == "llamacppserver":
|
365
|
+
return LlamaCppServerEmbeddings # type: ignore
|
254
366
|
else: # default sentence transformer
|
255
367
|
return SentenceTransformerEmbeddings # type: ignore
|
@@ -1,3 +1,6 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Callable
|
3
|
+
|
1
4
|
from dotenv import load_dotenv
|
2
5
|
from httpx import Timeout
|
3
6
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
@@ -15,6 +18,8 @@ azureStructuredOutputList = [
|
|
15
18
|
|
16
19
|
azureStructuredOutputAPIMin = "2024-08-01-preview"
|
17
20
|
|
21
|
+
logger = logging.getLogger(__name__)
|
22
|
+
|
18
23
|
|
19
24
|
class AzureConfig(OpenAIGPTConfig):
|
20
25
|
"""
|
@@ -42,6 +47,10 @@ class AzureConfig(OpenAIGPTConfig):
|
|
42
47
|
model_version: str = "" # is used to determine the cost of using the model
|
43
48
|
api_base: str = ""
|
44
49
|
|
50
|
+
# Alternatively, bring your own clients:
|
51
|
+
azure_openai_client_provider: Callable[[], AzureOpenAI] | None = None
|
52
|
+
azure_openai_async_client_provider: Callable[[], AsyncAzureOpenAI] | None = None
|
53
|
+
|
45
54
|
# all of the vars above can be set via env vars,
|
46
55
|
# by upper-casing the name and prefixing with `env_prefix`, e.g.
|
47
56
|
# AZURE_OPENAI_API_VERSION=2023-05-15
|
@@ -69,20 +78,6 @@ class AzureGPT(OpenAIGPT):
|
|
69
78
|
load_dotenv()
|
70
79
|
super().__init__(config)
|
71
80
|
self.config: AzureConfig = config
|
72
|
-
if self.config.api_key == "":
|
73
|
-
raise ValueError(
|
74
|
-
"""
|
75
|
-
AZURE_OPENAI_API_KEY not set in .env file,
|
76
|
-
please set it to your Azure API key."""
|
77
|
-
)
|
78
|
-
|
79
|
-
if self.config.api_base == "":
|
80
|
-
raise ValueError(
|
81
|
-
"""
|
82
|
-
AZURE_OPENAI_API_BASE not set in .env file,
|
83
|
-
please set it to your Azure API key."""
|
84
|
-
)
|
85
|
-
|
86
81
|
if self.config.deployment_name == "":
|
87
82
|
raise ValueError(
|
88
83
|
"""
|
@@ -98,6 +93,57 @@ class AzureGPT(OpenAIGPT):
|
|
98
93
|
please set it to chat model name in your deployment."""
|
99
94
|
)
|
100
95
|
|
96
|
+
if (
|
97
|
+
self.config.azure_openai_client_provider
|
98
|
+
or self.config.azure_openai_async_client_provider
|
99
|
+
):
|
100
|
+
if not self.config.azure_openai_client_provider:
|
101
|
+
self.client = None
|
102
|
+
logger.warning(
|
103
|
+
"Using user-provided Azure OpenAI client, but only async "
|
104
|
+
"client has been provided. Synchronous calls will fail."
|
105
|
+
)
|
106
|
+
if not self.config.azure_openai_async_client_provider:
|
107
|
+
self.async_client = None
|
108
|
+
logger.warning(
|
109
|
+
"Using user-provided Azure OpenAI client, but no async "
|
110
|
+
"client has been provided. Asynchronous calls will fail."
|
111
|
+
)
|
112
|
+
|
113
|
+
if self.config.azure_openai_client_provider:
|
114
|
+
self.client = self.config.azure_openai_client_provider()
|
115
|
+
if self.config.azure_openai_async_client_provider:
|
116
|
+
self.async_client = self.config.azure_openai_async_client_provider()
|
117
|
+
self.async_client.timeout = Timeout(self.config.timeout)
|
118
|
+
else:
|
119
|
+
if self.config.api_key == "":
|
120
|
+
raise ValueError(
|
121
|
+
"""
|
122
|
+
AZURE_OPENAI_API_KEY not set in .env file,
|
123
|
+
please set it to your Azure API key."""
|
124
|
+
)
|
125
|
+
|
126
|
+
if self.config.api_base == "":
|
127
|
+
raise ValueError(
|
128
|
+
"""
|
129
|
+
AZURE_OPENAI_API_BASE not set in .env file,
|
130
|
+
please set it to your Azure API key."""
|
131
|
+
)
|
132
|
+
|
133
|
+
self.client = AzureOpenAI(
|
134
|
+
api_key=self.config.api_key,
|
135
|
+
azure_endpoint=self.config.api_base,
|
136
|
+
api_version=self.config.api_version,
|
137
|
+
azure_deployment=self.config.deployment_name,
|
138
|
+
)
|
139
|
+
self.async_client = AsyncAzureOpenAI(
|
140
|
+
api_key=self.config.api_key,
|
141
|
+
azure_endpoint=self.config.api_base,
|
142
|
+
api_version=self.config.api_version,
|
143
|
+
azure_deployment=self.config.deployment_name,
|
144
|
+
timeout=Timeout(self.config.timeout),
|
145
|
+
)
|
146
|
+
|
101
147
|
# set the chat model to be the same as the model_name
|
102
148
|
# This corresponds to the gpt model you chose for your deployment
|
103
149
|
# when you deployed a model
|
@@ -108,20 +154,6 @@ class AzureGPT(OpenAIGPT):
|
|
108
154
|
and self.config.model_version in azureStructuredOutputList
|
109
155
|
)
|
110
156
|
|
111
|
-
self.client = AzureOpenAI(
|
112
|
-
api_key=self.config.api_key,
|
113
|
-
azure_endpoint=self.config.api_base,
|
114
|
-
api_version=self.config.api_version,
|
115
|
-
azure_deployment=self.config.deployment_name,
|
116
|
-
)
|
117
|
-
self.async_client = AsyncAzureOpenAI(
|
118
|
-
api_key=self.config.api_key,
|
119
|
-
azure_endpoint=self.config.api_base,
|
120
|
-
api_version=self.config.api_version,
|
121
|
-
azure_deployment=self.config.deployment_name,
|
122
|
-
timeout=Timeout(self.config.timeout),
|
123
|
-
)
|
124
|
-
|
125
157
|
def set_chat_model(self) -> None:
|
126
158
|
"""
|
127
159
|
Sets the chat model configuration based on the model name specified in the
|