langroid 0.28.7__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,229 @@
1
+ {
2
+ "components": {
3
+ "atoms": {
4
+ "buttons": {
5
+ "userButton": {
6
+ "menu": {
7
+ "settings": "\u8bbe\u7f6e",
8
+ "settingsKey": "S",
9
+ "APIKeys": "API \u5bc6\u94a5",
10
+ "logout": "\u767b\u51fa"
11
+ }
12
+ }
13
+ }
14
+ },
15
+ "molecules": {
16
+ "newChatButton": {
17
+ "newChat": "\u65b0\u5efa\u5bf9\u8bdd"
18
+ },
19
+ "tasklist": {
20
+ "TaskList": {
21
+ "title": "\ud83d\uddd2\ufe0f \u4efb\u52a1\u5217\u8868",
22
+ "loading": "\u52a0\u8f7d\u4e2d...",
23
+ "error": "\u53d1\u751f\u9519\u8bef"
24
+ }
25
+ },
26
+ "attachments": {
27
+ "cancelUpload": "\u53d6\u6d88\u4e0a\u4f20",
28
+ "removeAttachment": "\u79fb\u9664\u9644\u4ef6"
29
+ },
30
+ "newChatDialog": {
31
+ "createNewChat": "\u521b\u5efa\u65b0\u5bf9\u8bdd\uff1f",
32
+ "clearChat": "\u8fd9\u5c06\u6e05\u9664\u5f53\u524d\u6d88\u606f\u5e76\u5f00\u59cb\u65b0\u7684\u5bf9\u8bdd\u3002",
33
+ "cancel": "\u53d6\u6d88",
34
+ "confirm": "\u786e\u8ba4"
35
+ },
36
+ "settingsModal": {
37
+ "settings": "\u8bbe\u7f6e",
38
+ "expandMessages": "\u5c55\u5f00\u6d88\u606f",
39
+ "hideChainOfThought": "\u9690\u85cf\u601d\u8003\u94fe",
40
+ "darkMode": "\u6697\u8272\u6a21\u5f0f"
41
+ },
42
+ "detailsButton": {
43
+ "using": "\u4f7f\u7528",
44
+ "used": "\u5df2\u7528"
45
+ },
46
+ "auth": {
47
+ "authLogin": {
48
+ "title": "\u767b\u5f55\u4ee5\u8bbf\u95ee\u5e94\u7528\u3002",
49
+ "form": {
50
+ "email": "\u7535\u5b50\u90ae\u7bb1\u5730\u5740",
51
+ "password": "\u5bc6\u7801",
52
+ "noAccount": "\u6ca1\u6709\u8d26\u6237\uff1f",
53
+ "alreadyHaveAccount": "\u5df2\u6709\u8d26\u6237\uff1f",
54
+ "signup": "\u6ce8\u518c",
55
+ "signin": "\u767b\u5f55",
56
+ "or": "\u6216\u8005",
57
+ "continue": "\u7ee7\u7eed",
58
+ "forgotPassword": "\u5fd8\u8bb0\u5bc6\u7801\uff1f",
59
+ "passwordMustContain": "\u60a8\u7684\u5bc6\u7801\u5fc5\u987b\u5305\u542b\uff1a",
60
+ "emailRequired": "\u7535\u5b50\u90ae\u7bb1\u662f\u5fc5\u586b\u9879",
61
+ "passwordRequired": "\u5bc6\u7801\u662f\u5fc5\u586b\u9879"
62
+ },
63
+ "error": {
64
+ "default": "\u65e0\u6cd5\u767b\u5f55\u3002",
65
+ "signin": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
66
+ "oauthsignin": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
67
+ "redirect_uri_mismatch": "\u91cd\u5b9a\u5411URI\u4e0eOAuth\u5e94\u7528\u914d\u7f6e\u4e0d\u5339\u914d\u3002",
68
+ "oauthcallbackerror": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
69
+ "oauthcreateaccount": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
70
+ "emailcreateaccount": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
71
+ "callback": "\u5c1d\u8bd5\u4f7f\u7528\u4e0d\u540c\u7684\u8d26\u6237\u767b\u5f55\u3002",
72
+ "oauthaccountnotlinked": "\u4e3a\u4e86\u9a8c\u8bc1\u60a8\u7684\u8eab\u4efd\uff0c\u8bf7\u4f7f\u7528\u6700\u521d\u4f7f\u7528\u7684\u540c\u4e00\u8d26\u6237\u767b\u5f55\u3002",
73
+ "emailsignin": "\u65e0\u6cd5\u53d1\u9001\u90ae\u4ef6\u3002",
74
+ "emailverify": "\u8bf7\u9a8c\u8bc1\u60a8\u7684\u7535\u5b50\u90ae\u4ef6\uff0c\u5df2\u53d1\u9001\u4e00\u5c01\u65b0\u90ae\u4ef6\u3002",
75
+ "credentialssignin": "\u767b\u5f55\u5931\u8d25\u3002\u8bf7\u68c0\u67e5\u60a8\u63d0\u4f9b\u7684\u8be6\u7ec6\u4fe1\u606f\u662f\u5426\u6b63\u786e\u3002",
76
+ "sessionrequired": "\u8bf7\u767b\u5f55\u4ee5\u8bbf\u95ee\u6b64\u9875\u9762\u3002"
77
+ }
78
+ },
79
+ "authVerifyEmail": {
80
+ "almostThere": "\u60a8\u5feb\u6210\u529f\u4e86\uff01\u6211\u4eec\u5df2\u5411 ",
81
+ "verifyEmailLink": "\u8bf7\u5355\u51fb\u8be5\u90ae\u4ef6\u4e2d\u7684\u94fe\u63a5\u4ee5\u5b8c\u6210\u6ce8\u518c\u3002",
82
+ "didNotReceive": "\u6ca1\u627e\u5230\u90ae\u4ef6\uff1f",
83
+ "resendEmail": "\u91cd\u65b0\u53d1\u9001\u90ae\u4ef6",
84
+ "goBack": "\u8fd4\u56de",
85
+ "emailSent": "\u90ae\u4ef6\u5df2\u6210\u529f\u53d1\u9001\u3002",
86
+ "verifyEmail": "\u9a8c\u8bc1\u60a8\u7684\u7535\u5b50\u90ae\u4ef6\u5730\u5740"
87
+ },
88
+ "providerButton": {
89
+ "continue": "\u4f7f\u7528{{provider}}\u7ee7\u7eed",
90
+ "signup": "\u4f7f\u7528{{provider}}\u6ce8\u518c"
91
+ },
92
+ "authResetPassword": {
93
+ "newPasswordRequired": "\u65b0\u5bc6\u7801\u662f\u5fc5\u586b\u9879",
94
+ "passwordsMustMatch": "\u5bc6\u7801\u5fc5\u987b\u4e00\u81f4",
95
+ "confirmPasswordRequired": "\u786e\u8ba4\u5bc6\u7801\u662f\u5fc5\u586b\u9879",
96
+ "newPassword": "\u65b0\u5bc6\u7801",
97
+ "confirmPassword": "\u786e\u8ba4\u5bc6\u7801",
98
+ "resetPassword": "\u91cd\u7f6e\u5bc6\u7801"
99
+ },
100
+ "authForgotPassword": {
101
+ "email": "\u7535\u5b50\u90ae\u7bb1\u5730\u5740",
102
+ "emailRequired": "\u7535\u5b50\u90ae\u7bb1\u662f\u5fc5\u586b\u9879",
103
+ "emailSent": "\u8bf7\u68c0\u67e5\u7535\u5b50\u90ae\u7bb1{{email}}\u4ee5\u83b7\u53d6\u91cd\u7f6e\u5bc6\u7801\u7684\u6307\u793a\u3002",
104
+ "enterEmail": "\u8bf7\u8f93\u5165\u60a8\u7684\u7535\u5b50\u90ae\u7bb1\u5730\u5740\uff0c\u6211\u4eec\u5c06\u53d1\u9001\u91cd\u7f6e\u5bc6\u7801\u7684\u6307\u793a\u3002",
105
+ "resendEmail": "\u91cd\u65b0\u53d1\u9001\u90ae\u4ef6",
106
+ "continue": "\u7ee7\u7eed",
107
+ "goBack": "\u8fd4\u56de"
108
+ }
109
+ }
110
+ },
111
+ "organisms": {
112
+ "chat": {
113
+ "history": {
114
+ "index": {
115
+ "showHistory": "\u663e\u793a\u5386\u53f2",
116
+ "lastInputs": "\u6700\u540e\u8f93\u5165",
117
+ "noInputs": "\u5982\u6b64\u7a7a\u65f7...",
118
+ "loading": "\u52a0\u8f7d\u4e2d..."
119
+ }
120
+ },
121
+ "inputBox": {
122
+ "input": {
123
+ "placeholder": "\u5728\u8fd9\u91cc\u8f93\u5165\u60a8\u7684\u6d88\u606f..."
124
+ },
125
+ "speechButton": {
126
+ "start": "\u5f00\u59cb\u5f55\u97f3",
127
+ "stop": "\u505c\u6b62\u5f55\u97f3"
128
+ },
129
+ "SubmitButton": {
130
+ "sendMessage": "\u53d1\u9001\u6d88\u606f",
131
+ "stopTask": "\u505c\u6b62\u4efb\u52a1"
132
+ },
133
+ "UploadButton": {
134
+ "attachFiles": "\u9644\u52a0\u6587\u4ef6"
135
+ },
136
+ "waterMark": {
137
+ "text": "\u4f7f\u7528"
138
+ }
139
+ },
140
+ "Messages": {
141
+ "index": {
142
+ "running": "\u8fd0\u884c\u4e2d",
143
+ "executedSuccessfully": "\u6267\u884c\u6210\u529f",
144
+ "failed": "\u5931\u8d25",
145
+ "feedbackUpdated": "\u53cd\u9988\u66f4\u65b0",
146
+ "updating": "\u6b63\u5728\u66f4\u65b0"
147
+ }
148
+ },
149
+ "dropScreen": {
150
+ "dropYourFilesHere": "\u5728\u8fd9\u91cc\u62d6\u653e\u60a8\u7684\u6587\u4ef6"
151
+ },
152
+ "index": {
153
+ "failedToUpload": "\u4e0a\u4f20\u5931\u8d25",
154
+ "cancelledUploadOf": "\u53d6\u6d88\u4e0a\u4f20",
155
+ "couldNotReachServer": "\u65e0\u6cd5\u8fde\u63a5\u5230\u670d\u52a1\u5668",
156
+ "continuingChat": "\u7ee7\u7eed\u4e4b\u524d\u7684\u5bf9\u8bdd"
157
+ },
158
+ "settings": {
159
+ "settingsPanel": "\u8bbe\u7f6e\u9762\u677f",
160
+ "reset": "\u91cd\u7f6e",
161
+ "cancel": "\u53d6\u6d88",
162
+ "confirm": "\u786e\u8ba4"
163
+ }
164
+ },
165
+ "threadHistory": {
166
+ "sidebar": {
167
+ "filters": {
168
+ "FeedbackSelect": {
169
+ "feedbackAll": "\u53cd\u9988\uff1a\u5168\u90e8",
170
+ "feedbackPositive": "\u53cd\u9988\uff1a\u6b63\u9762",
171
+ "feedbackNegative": "\u53cd\u9988\uff1a\u8d1f\u9762"
172
+ },
173
+ "SearchBar": {
174
+ "search": "\u641c\u7d22"
175
+ }
176
+ },
177
+ "DeleteThreadButton": {
178
+ "confirmMessage": "\u8fd9\u5c06\u5220\u9664\u7ebf\u7a0b\u53ca\u5176\u6d88\u606f\u548c\u5143\u7d20\u3002",
179
+ "cancel": "\u53d6\u6d88",
180
+ "confirm": "\u786e\u8ba4",
181
+ "deletingChat": "\u5220\u9664\u5bf9\u8bdd",
182
+ "chatDeleted": "\u5bf9\u8bdd\u5df2\u5220\u9664"
183
+ },
184
+ "index": {
185
+ "pastChats": "\u8fc7\u5f80\u5bf9\u8bdd"
186
+ },
187
+ "ThreadList": {
188
+ "empty": "\u7a7a\u7684...",
189
+ "today": "\u4eca\u5929",
190
+ "yesterday": "\u6628\u5929",
191
+ "previous7days": "\u524d7\u5929",
192
+ "previous30days": "\u524d30\u5929"
193
+ },
194
+ "TriggerButton": {
195
+ "closeSidebar": "\u5173\u95ed\u4fa7\u8fb9\u680f",
196
+ "openSidebar": "\u6253\u5f00\u4fa7\u8fb9\u680f"
197
+ }
198
+ },
199
+ "Thread": {
200
+ "backToChat": "\u8fd4\u56de\u5bf9\u8bdd",
201
+ "chatCreatedOn": "\u6b64\u5bf9\u8bdd\u521b\u5efa\u4e8e"
202
+ }
203
+ },
204
+ "header": {
205
+ "chat": "\u5bf9\u8bdd",
206
+ "readme": "\u8bf4\u660e"
207
+ }
208
+ }
209
+ },
210
+ "hooks": {
211
+ "useLLMProviders": {
212
+ "failedToFetchProviders": "\u83b7\u53d6\u63d0\u4f9b\u8005\u5931\u8d25:"
213
+ }
214
+ },
215
+ "pages": {
216
+ "Design": {},
217
+ "Env": {
218
+ "savedSuccessfully": "\u4fdd\u5b58\u6210\u529f",
219
+ "requiredApiKeys": "\u5fc5\u9700\u7684API\u5bc6\u94a5",
220
+ "requiredApiKeysInfo": "\u8981\u4f7f\u7528\u6b64\u5e94\u7528\uff0c\u9700\u8981\u4ee5\u4e0bAPI\u5bc6\u94a5\u3002\u8fd9\u4e9b\u5bc6\u94a5\u5b58\u50a8\u5728\u60a8\u7684\u8bbe\u5907\u672c\u5730\u5b58\u50a8\u4e2d\u3002"
221
+ },
222
+ "Page": {
223
+ "notPartOfProject": "\u60a8\u4e0d\u662f\u6b64\u9879\u76ee\u7684\u4e00\u90e8\u5206\u3002"
224
+ },
225
+ "ResumeButton": {
226
+ "resumeChat": "\u6062\u590d\u5bf9\u8bdd"
227
+ }
228
+ }
229
+ }
@@ -9,8 +9,10 @@ from .base import (
9
9
  from .models import (
10
10
  OpenAIEmbeddings,
11
11
  OpenAIEmbeddingsConfig,
12
- SentenceTransformerEmbeddingsConfig,
13
12
  SentenceTransformerEmbeddings,
13
+ SentenceTransformerEmbeddingsConfig,
14
+ LlamaCppServerEmbeddings,
15
+ LlamaCppServerEmbeddingsConfig,
14
16
  embedding_model,
15
17
  )
16
18
  from .remote_embeds import (
@@ -27,8 +29,10 @@ __all__ = [
27
29
  "EmbeddingModelsConfig",
28
30
  "OpenAIEmbeddings",
29
31
  "OpenAIEmbeddingsConfig",
30
- "SentenceTransformerEmbeddingsConfig",
31
32
  "SentenceTransformerEmbeddings",
33
+ "SentenceTransformerEmbeddingsConfig",
34
+ "LlamaCppServerEmbeddings",
35
+ "LlamaCppServerEmbeddingsConfig",
32
36
  "embedding_model",
33
37
  "RemoteEmbeddingsConfig",
34
38
  "RemoteEmbeddings",
@@ -26,6 +26,8 @@ class EmbeddingModel(ABC):
26
26
  from langroid.embedding_models.models import (
27
27
  FastEmbedEmbeddings,
28
28
  FastEmbedEmbeddingsConfig,
29
+ LlamaCppServerEmbeddings,
30
+ LlamaCppServerEmbeddingsConfig,
29
31
  OpenAIEmbeddings,
30
32
  OpenAIEmbeddingsConfig,
31
33
  SentenceTransformerEmbeddings,
@@ -44,6 +46,8 @@ class EmbeddingModel(ABC):
44
46
  return SentenceTransformerEmbeddings(config)
45
47
  elif isinstance(config, FastEmbedEmbeddingsConfig):
46
48
  return FastEmbedEmbeddings(config)
49
+ elif isinstance(config, LlamaCppServerEmbeddingsConfig):
50
+ return LlamaCppServerEmbeddings(config)
47
51
  else:
48
52
  raise ValueError(f"Unknown embedding config: {config.__repr_name__}")
49
53
 
@@ -3,6 +3,7 @@ import os
3
3
  from functools import cached_property
4
4
  from typing import Any, Callable, Dict, List, Optional
5
5
 
6
+ import requests
6
7
  import tiktoken
7
8
  from dotenv import load_dotenv
8
9
  from openai import OpenAI
@@ -48,13 +49,19 @@ class FastEmbedEmbeddingsConfig(EmbeddingModelsConfig):
48
49
  additional_kwargs: Dict[str, Any] = {}
49
50
 
50
51
 
52
class LlamaCppServerEmbeddingsConfig(EmbeddingModelsConfig):
    """
    Settings for `LlamaCppServerEmbeddings`, an embedding model that talks to
    a llama.cpp server over HTTP.

    Attributes:
        api_base: Root URL of the server. It is empty by default, and
            `LlamaCppServerEmbeddings` raises ValueError at construction
            unless it is set.
        context_length: Token limit applied by
            `LlamaCppServerEmbeddings.truncate_string_to_context_size`.
        batch_size: Handed to `EmbeddingFunctionCallable` by
            `LlamaCppServerEmbeddings.embedding_fn`. It is used there as the
            number of tokens per chunk sent to the embedding endpoint.
    """

    # No usable default server address; callers must supply one.
    api_base: str = ""
    # Maximum number of tokens kept when truncating text to the context.
    context_length: int = 2048
    # Tokens per embedding request chunk (see class docstring).
    batch_size: int = 2048
56
+
57
+
51
58
  class EmbeddingFunctionCallable:
52
59
  """
53
60
  A callable class designed to generate embeddings for a list of texts using
54
61
  the OpenAI API, with automatic retries on failure.
55
62
 
56
63
  Attributes:
57
- model (OpenAIEmbeddings): An instance of OpenAIEmbeddings that provides
64
+ embed_model (EmbeddingModel): An instance of EmbeddingModel that provides
58
65
  configuration and utilities for generating embeddings.
59
66
 
60
67
  Methods:
@@ -62,7 +69,7 @@ class EmbeddingFunctionCallable:
62
69
  a list of input texts.
63
70
  """
64
71
 
65
- def __init__(self, model: "OpenAIEmbeddings", batch_size: int = 512):
72
+ def __init__(self, embed_model: EmbeddingModel, batch_size: int = 512):
66
73
  """
67
74
  Initialize the EmbeddingFunctionCallable with a specific model.
68
75
 
@@ -71,7 +78,7 @@ class EmbeddingFunctionCallable:
71
78
  generating embeddings.
72
79
  batch_size (int): Batch size
73
80
  """
74
- self.model = model
81
+ self.embed_model = embed_model
75
82
  self.batch_size = batch_size
76
83
 
77
84
  def __call__(self, input: List[str]) -> Embeddings:
@@ -91,14 +98,46 @@ class EmbeddingFunctionCallable:
91
98
  Returns:
92
99
  Embeddings: A list of embedding vectors corresponding to the input texts.
93
100
  """
94
- tokenized_texts = self.model.truncate_texts(input)
95
101
  embeds = []
96
- for batch in batched(tokenized_texts, self.batch_size):
97
- result = self.model.client.embeddings.create(
98
- input=batch, model=self.model.config.model_name
102
+
103
+ if isinstance(self.embed_model, OpenAIEmbeddings):
104
+ tokenized_texts = self.embed_model.truncate_texts(input)
105
+
106
+ for batch in batched(tokenized_texts, self.batch_size):
107
+ result = self.embed_model.client.embeddings.create(
108
+ input=batch, model=self.embed_model.config.model_name
109
+ )
110
+ batch_embeds = [d.embedding for d in result.data]
111
+ embeds.extend(batch_embeds)
112
+
113
+ elif isinstance(self.embed_model, SentenceTransformerEmbeddings):
114
+ if self.embed_model.config.data_parallel:
115
+ embeds = self.embed_model.model.encode_multi_process(
116
+ input,
117
+ self.embed_model.pool,
118
+ batch_size=self.batch_size,
119
+ ).tolist()
120
+ else:
121
+ for str_batch in batched(input, self.batch_size):
122
+ batch_embeds = self.embed_model.model.encode(
123
+ str_batch, convert_to_numpy=True
124
+ ).tolist() # type: ignore
125
+ embeds.extend(batch_embeds)
126
+
127
+ elif isinstance(self.embed_model, FastEmbedEmbeddings):
128
+ embeddings = self.embed_model.model.embed(
129
+ input, batch_size=self.batch_size, parallel=self.embed_model.parallel
99
130
  )
100
- batch_embeds = [d.embedding for d in result.data]
101
- embeds.extend(batch_embeds)
131
+
132
+ embeds = [embedding.tolist() for embedding in embeddings]
133
+ elif isinstance(self.embed_model, LlamaCppServerEmbeddings):
134
+ for input_string in input:
135
+ tokenized_text = self.embed_model.tokenize_string(input_string)
136
+ for token_batch in batched(tokenized_text, self.batch_size):
137
+ gen_embedding = self.embed_model.generate_embedding(
138
+ self.embed_model.detokenize_string(list(token_batch))
139
+ )
140
+ embeds.append(gen_embedding)
102
141
  return embeds
103
142
 
104
143
 
@@ -176,24 +215,7 @@ class SentenceTransformerEmbeddings(EmbeddingModel):
176
215
  self.config.context_length = self.tokenizer.model_max_length
177
216
 
178
217
  def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
179
- def fn(texts: List[str]) -> Embeddings:
180
- if self.config.data_parallel:
181
- embeds: Embeddings = self.model.encode_multi_process(
182
- texts,
183
- self.pool,
184
- batch_size=self.config.batch_size,
185
- ).tolist()
186
- else:
187
- embeds = []
188
- for batch in batched(texts, self.config.batch_size):
189
- batch_embeds = self.model.encode(
190
- batch, convert_to_numpy=True
191
- ).tolist() # type: ignore
192
- embeds.extend(batch_embeds)
193
-
194
- return embeds
195
-
196
- return fn
218
+ return EmbeddingFunctionCallable(self, self.config.batch_size)
197
219
 
198
220
  @property
199
221
  def embedding_dims(self) -> int:
@@ -214,10 +236,10 @@ class FastEmbedEmbeddings(EmbeddingModel):
214
236
 
215
237
  super().__init__()
216
238
  self.config = config
217
- self._batch_size = config.batch_size
218
- self._parallel = config.parallel
239
+ self.batch_size = config.batch_size
240
+ self.parallel = config.parallel
219
241
 
220
- self._model = TextEmbedding(
242
+ self.model = TextEmbedding(
221
243
  model_name=self.config.model_name,
222
244
  cache_dir=self.config.cache_dir,
223
245
  threads=self.config.threads,
@@ -225,14 +247,7 @@ class FastEmbedEmbeddings(EmbeddingModel):
225
247
  )
226
248
 
227
249
  def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
228
- def fn(texts: List[str]) -> Embeddings:
229
- embeddings = self._model.embed(
230
- texts, batch_size=self._batch_size, parallel=self._parallel
231
- )
232
-
233
- return [embedding.tolist() for embedding in embeddings]
234
-
235
- return fn
250
+ return EmbeddingFunctionCallable(self, self.config.batch_size)
236
251
 
237
252
  @cached_property
238
253
  def embedding_dims(self) -> int:
@@ -240,10 +255,105 @@ class FastEmbedEmbeddings(EmbeddingModel):
240
255
  return len(embed_func(["text"])[0])
241
256
 
242
257
 
258
LCSEC = LlamaCppServerEmbeddingsConfig


class LlamaCppServerEmbeddings(EmbeddingModel):
    """
    Embedding model that delegates all work to a running llama.cpp server.

    Tokenization, detokenization and embedding generation are each a JSON
    POST to an endpoint under `config.api_base`: `/tokenize`, `/detokenize`
    and `/embeddings`. No request timeout is passed to `requests.post`, so an
    unresponsive server can block the caller indefinitely.

    NOTE(review): `embedding_fn()` hands this model to
    `EmbeddingFunctionCallable`. That class splits every input text into
    chunks of `config.batch_size` tokens and requests one embedding per
    chunk. An input longer than `batch_size` tokens therefore yields several
    vectors rather than one. Confirm that callers expect this.
    """

    def __init__(self, config: LCSEC = LCSEC()):
        """
        Args:
            config: Server settings; `api_base` must be non-empty.

        Raises:
            ValueError: if `config.api_base` is empty.
        """
        super().__init__()
        self.config = config

        if self.config.api_base == "":
            raise ValueError(
                """Api Base MUST be set for Llama Server Embeddings.
                """
            )

        # Strip trailing slashes so that "http://host:8080/" does not produce
        # "http://host:8080//tokenize", which a server may not route.
        base_url = self.config.api_base.rstrip("/")
        self.tokenize_url = base_url + "/tokenize"
        self.detokenize_url = base_url + "/detokenize"
        self.embedding_url = base_url + "/embeddings"

    @staticmethod
    def _is_numeric_list(value: Any) -> bool:
        """Return True iff `value` is a list whose every item is an int or float."""
        return isinstance(value, list) and all(
            isinstance(item, (int, float)) for item in value
        )

    @staticmethod
    def _field(body: Any, key: str) -> Any:
        """Return `body[key]` when `body` is a dict, otherwise None.

        This lets a malformed response fall through to the callers' format
        check (ValueError) instead of surfacing as KeyError/TypeError.
        """
        return body.get(key) if isinstance(body, dict) else None

    def _post_json(self, url: str, payload: Dict[str, Any], provider: str) -> Any:
        """
        POST `payload` as JSON to `url` and return the decoded response body.

        Args:
            url: Endpoint to call.
            payload: Request body, serialized as JSON by `requests`.
            provider: Short label for error messages, e.g. "tokenization".

        Raises:
            requests.HTTPError: if the server replies with any status other
                than 200. The response is attached as `err.response`.
        """
        response = requests.post(url, json=payload)
        if response.status_code != 200:
            raise requests.HTTPError(
                f"Failed to connect to {provider} provider: "
                f"POST {url} returned HTTP {response.status_code}",
                response=response,
            )
        return response.json()

    def tokenize_string(self, text: str) -> List[int]:
        """
        Tokenize `text` with the server's tokenizer, without special tokens.

        Raises:
            requests.HTTPError: on a non-200 response.
            ValueError: if the response lacks a `tokens` list of numbers.
        """
        body = self._post_json(
            self.tokenize_url,
            {"content": text, "add_special": False, "with_pieces": False},
            "tokenization",
        )
        tokens = self._field(body, "tokens")
        # Validate every element, not just the first. An empty list is
        # accepted (e.g. for empty `text`) rather than tripping an IndexError.
        if not self._is_numeric_list(tokens):
            raise ValueError(
                """Tokenizer endpoint has not returned the correct format.
                Is the URL correct?
                """
            )
        return tokens

    def detokenize_string(self, tokens: List[int]) -> str:
        """
        Convert `tokens` back into text with the server's tokenizer.

        Raises:
            requests.HTTPError: on a non-200 response.
            ValueError: if the response lacks a string `content` field.
        """
        body = self._post_json(
            self.detokenize_url, {"tokens": tokens}, "detokenization"
        )
        text = self._field(body, "content")
        if not isinstance(text, str):
            raise ValueError(
                """Detokenizer endpoint has not returned the correct format.
                Is the URL correct?
                """
            )
        return text

    def truncate_string_to_context_size(self, text: str) -> str:
        """
        Return `text` cut to at most `config.context_length` tokens.

        The text round-trips through the server: it is tokenized, sliced,
        then detokenized.
        """
        tokens = self.tokenize_string(text)
        tokens = tokens[: self.config.context_length]
        return self.detokenize_string(tokens)

    def generate_embedding(self, text: str) -> List[int | float]:
        """
        Request a single embedding vector for `text`.

        Raises:
            requests.HTTPError: on a non-200 response.
            ValueError: if the response lacks a non-empty `embedding` list of
                numbers.
        """
        body = self._post_json(self.embedding_url, {"content": text}, "embedding")
        embedding = self._field(body, "embedding")
        # A usable embedding must be a non-empty, all-numeric list.
        if not (self._is_numeric_list(embedding) and embedding):
            raise ValueError(
                """Embedding endpoint has not returned the correct format.
                Is the URL correct?
                """
            )
        return embedding

    def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
        """Return a callable that embeds a list of strings via this server."""
        return EmbeddingFunctionCallable(self, self.config.batch_size)

    @property
    def embedding_dims(self) -> int:
        """Embedding dimensionality, read from `config.dims`.

        NOTE(review): not queried from the server. Presumably `dims` must be
        set explicitly in the config; confirm against EmbeddingModelsConfig.
        """
        return self.config.dims
350
+
351
+
243
352
  def embedding_model(embedding_fn_type: str = "openai") -> EmbeddingModel:
244
353
  """
245
354
  Args:
246
- embedding_fn_type: "openai" or "sentencetransformer" # others soon
355
+ embedding_fn_type: "openai" or "fastembed" or
356
+ "llamacppserver" or "sentencetransformer" # others soon
247
357
  Returns:
248
358
  EmbeddingModel
249
359
  """
@@ -251,5 +361,7 @@ def embedding_model(embedding_fn_type: str = "openai") -> EmbeddingModel:
251
361
  return OpenAIEmbeddings # type: ignore
252
362
  elif embedding_fn_type == "fastembed":
253
363
  return FastEmbedEmbeddings # type: ignore
364
+ elif embedding_fn_type == "llamacppserver":
365
+ return LlamaCppServerEmbeddings # type: ignore
254
366
  else: # default sentence transformer
255
367
  return SentenceTransformerEmbeddings # type: ignore
@@ -1,3 +1,6 @@
1
+ import logging
2
+ from typing import Callable
3
+
1
4
  from dotenv import load_dotenv
2
5
  from httpx import Timeout
3
6
  from openai import AsyncAzureOpenAI, AzureOpenAI
@@ -15,6 +18,8 @@ azureStructuredOutputList = [
15
18
 
16
19
  azureStructuredOutputAPIMin = "2024-08-01-preview"
17
20
 
21
+ logger = logging.getLogger(__name__)
22
+
18
23
 
19
24
  class AzureConfig(OpenAIGPTConfig):
20
25
  """
@@ -42,6 +47,10 @@ class AzureConfig(OpenAIGPTConfig):
42
47
  model_version: str = "" # is used to determine the cost of using the model
43
48
  api_base: str = ""
44
49
 
50
+ # Alternatively, bring your own clients:
51
+ azure_openai_client_provider: Callable[[], AzureOpenAI] | None = None
52
+ azure_openai_async_client_provider: Callable[[], AsyncAzureOpenAI] | None = None
53
+
45
54
  # all of the vars above can be set via env vars,
46
55
  # by upper-casing the name and prefixing with `env_prefix`, e.g.
47
56
  # AZURE_OPENAI_API_VERSION=2023-05-15
@@ -69,20 +78,6 @@ class AzureGPT(OpenAIGPT):
69
78
  load_dotenv()
70
79
  super().__init__(config)
71
80
  self.config: AzureConfig = config
72
- if self.config.api_key == "":
73
- raise ValueError(
74
- """
75
- AZURE_OPENAI_API_KEY not set in .env file,
76
- please set it to your Azure API key."""
77
- )
78
-
79
- if self.config.api_base == "":
80
- raise ValueError(
81
- """
82
- AZURE_OPENAI_API_BASE not set in .env file,
83
- please set it to your Azure API key."""
84
- )
85
-
86
81
  if self.config.deployment_name == "":
87
82
  raise ValueError(
88
83
  """
@@ -98,6 +93,57 @@ class AzureGPT(OpenAIGPT):
98
93
  please set it to chat model name in your deployment."""
99
94
  )
100
95
 
96
+ if (
97
+ self.config.azure_openai_client_provider
98
+ or self.config.azure_openai_async_client_provider
99
+ ):
100
+ if not self.config.azure_openai_client_provider:
101
+ self.client = None
102
+ logger.warning(
103
+ "Using user-provided Azure OpenAI client, but only async "
104
+ "client has been provided. Synchronous calls will fail."
105
+ )
106
+ if not self.config.azure_openai_async_client_provider:
107
+ self.async_client = None
108
+ logger.warning(
109
+ "Using user-provided Azure OpenAI client, but no async "
110
+ "client has been provided. Asynchronous calls will fail."
111
+ )
112
+
113
+ if self.config.azure_openai_client_provider:
114
+ self.client = self.config.azure_openai_client_provider()
115
+ if self.config.azure_openai_async_client_provider:
116
+ self.async_client = self.config.azure_openai_async_client_provider()
117
+ self.async_client.timeout = Timeout(self.config.timeout)
118
+ else:
119
+ if self.config.api_key == "":
120
+ raise ValueError(
121
+ """
122
+ AZURE_OPENAI_API_KEY not set in .env file,
123
+ please set it to your Azure API key."""
124
+ )
125
+
126
+ if self.config.api_base == "":
127
+ raise ValueError(
128
+ """
129
+ AZURE_OPENAI_API_BASE not set in .env file,
130
+ please set it to your Azure API key."""
131
+ )
132
+
133
+ self.client = AzureOpenAI(
134
+ api_key=self.config.api_key,
135
+ azure_endpoint=self.config.api_base,
136
+ api_version=self.config.api_version,
137
+ azure_deployment=self.config.deployment_name,
138
+ )
139
+ self.async_client = AsyncAzureOpenAI(
140
+ api_key=self.config.api_key,
141
+ azure_endpoint=self.config.api_base,
142
+ api_version=self.config.api_version,
143
+ azure_deployment=self.config.deployment_name,
144
+ timeout=Timeout(self.config.timeout),
145
+ )
146
+
101
147
  # set the chat model to be the same as the model_name
102
148
  # This corresponds to the gpt model you chose for your deployment
103
149
  # when you deployed a model
@@ -108,20 +154,6 @@ class AzureGPT(OpenAIGPT):
108
154
  and self.config.model_version in azureStructuredOutputList
109
155
  )
110
156
 
111
- self.client = AzureOpenAI(
112
- api_key=self.config.api_key,
113
- azure_endpoint=self.config.api_base,
114
- api_version=self.config.api_version,
115
- azure_deployment=self.config.deployment_name,
116
- )
117
- self.async_client = AsyncAzureOpenAI(
118
- api_key=self.config.api_key,
119
- azure_endpoint=self.config.api_base,
120
- api_version=self.config.api_version,
121
- azure_deployment=self.config.deployment_name,
122
- timeout=Timeout(self.config.timeout),
123
- )
124
-
125
157
  def set_chat_model(self) -> None:
126
158
  """
127
159
  Sets the chat model configuration based on the model name specified in the