langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +39 -17
- langroid/agent/base.py-e +2216 -0
- langroid/agent/callbacks/chainlit.py +2 -1
- langroid/agent/chat_agent.py +73 -55
- langroid/agent/chat_agent.py-e +2086 -0
- langroid/agent/chat_document.py +7 -7
- langroid/agent/chat_document.py-e +513 -0
- langroid/agent/openai_assistant.py +9 -9
- langroid/agent/openai_assistant.py-e +882 -0
- langroid/agent/special/arangodb/arangodb_agent.py +10 -18
- langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
- langroid/agent/special/arangodb/tools.py +3 -3
- langroid/agent/special/doc_chat_agent.py +16 -14
- langroid/agent/special/lance_rag/critic_agent.py +2 -2
- langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
- langroid/agent/special/lance_tools.py +6 -5
- langroid/agent/special/lance_tools.py-e +61 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
- langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
- langroid/agent/special/relevance_extractor_agent.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +11 -3
- langroid/agent/task.py +9 -87
- langroid/agent/task.py-e +2418 -0
- langroid/agent/tool_message.py +33 -17
- langroid/agent/tool_message.py-e +400 -0
- langroid/agent/tools/file_tools.py +4 -2
- langroid/agent/tools/file_tools.py-e +234 -0
- langroid/agent/tools/mcp/fastmcp_client.py +19 -6
- langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
- langroid/agent/tools/orchestration.py +22 -17
- langroid/agent/tools/orchestration.py-e +301 -0
- langroid/agent/tools/recipient_tool.py +3 -3
- langroid/agent/tools/task_tool.py +22 -16
- langroid/agent/tools/task_tool.py-e +249 -0
- langroid/agent/xml_tool_message.py +90 -35
- langroid/agent/xml_tool_message.py-e +392 -0
- langroid/cachedb/base.py +1 -1
- langroid/embedding_models/base.py +2 -2
- langroid/embedding_models/models.py +3 -7
- langroid/embedding_models/models.py-e +563 -0
- langroid/exceptions.py +4 -1
- langroid/language_models/azure_openai.py +2 -2
- langroid/language_models/azure_openai.py-e +134 -0
- langroid/language_models/base.py +6 -4
- langroid/language_models/base.py-e +812 -0
- langroid/language_models/client_cache.py +64 -0
- langroid/language_models/config.py +2 -4
- langroid/language_models/config.py-e +18 -0
- langroid/language_models/model_info.py +9 -1
- langroid/language_models/model_info.py-e +483 -0
- langroid/language_models/openai_gpt.py +119 -20
- langroid/language_models/openai_gpt.py-e +2280 -0
- langroid/language_models/provider_params.py +3 -22
- langroid/language_models/provider_params.py-e +153 -0
- langroid/mytypes.py +11 -4
- langroid/mytypes.py-e +132 -0
- langroid/parsing/code_parser.py +1 -1
- langroid/parsing/file_attachment.py +1 -1
- langroid/parsing/file_attachment.py-e +246 -0
- langroid/parsing/md_parser.py +14 -4
- langroid/parsing/md_parser.py-e +574 -0
- langroid/parsing/parser.py +22 -7
- langroid/parsing/parser.py-e +410 -0
- langroid/parsing/repo_loader.py +3 -1
- langroid/parsing/repo_loader.py-e +812 -0
- langroid/parsing/search.py +1 -1
- langroid/parsing/url_loader.py +17 -51
- langroid/parsing/url_loader.py-e +683 -0
- langroid/parsing/urls.py +5 -4
- langroid/parsing/urls.py-e +279 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/pydantic_v1/__init__.py +45 -6
- langroid/pydantic_v1/__init__.py-e +36 -0
- langroid/pydantic_v1/main.py +11 -4
- langroid/pydantic_v1/main.py-e +11 -0
- langroid/utils/configuration.py +13 -11
- langroid/utils/configuration.py-e +141 -0
- langroid/utils/constants.py +1 -1
- langroid/utils/constants.py-e +32 -0
- langroid/utils/globals.py +21 -5
- langroid/utils/globals.py-e +49 -0
- langroid/utils/html_logger.py +2 -1
- langroid/utils/html_logger.py-e +825 -0
- langroid/utils/object_registry.py +1 -1
- langroid/utils/object_registry.py-e +66 -0
- langroid/utils/pydantic_utils.py +55 -28
- langroid/utils/pydantic_utils.py-e +602 -0
- langroid/utils/types.py +2 -2
- langroid/utils/types.py-e +113 -0
- langroid/vector_store/base.py +3 -3
- langroid/vector_store/lancedb.py +5 -5
- langroid/vector_store/lancedb.py-e +404 -0
- langroid/vector_store/meilisearch.py +2 -2
- langroid/vector_store/pineconedb.py +4 -4
- langroid/vector_store/pineconedb.py-e +427 -0
- langroid/vector_store/postgres.py +1 -1
- langroid/vector_store/qdrantdb.py +3 -3
- langroid/vector_store/weaviatedb.py +1 -1
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
- langroid-0.59.0b1.dist-info/RECORD +181 -0
- langroid/agent/special/doc_chat_task.py +0 -0
- langroid/mcp/__init__.py +0 -1
- langroid/mcp/server/__init__.py +0 -1
- langroid-0.58.2.dist-info/RECORD +0 -145
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,563 @@
|
|
1
|
+
import atexit
|
2
|
+
import os
|
3
|
+
from functools import cached_property
|
4
|
+
from typing import Any, Callable, Dict, List, Optional
|
5
|
+
|
6
|
+
import requests
|
7
|
+
import tiktoken
|
8
|
+
from dotenv import load_dotenv
|
9
|
+
from openai import AzureOpenAI, OpenAI
|
10
|
+
|
11
|
+
from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
|
12
|
+
from langroid.exceptions import LangroidImportError
|
13
|
+
from langroid.language_models.provider_params import LangDBParams
|
14
|
+
from langroid.mytypes import Embeddings
|
15
|
+
from langroid.parsing.utils import batched
|
16
|
+
from pydantic import ConfigDict
|
17
|
+
|
18
|
+
AzureADTokenProvider = Callable[[], str]
|
19
|
+
|
20
|
+
|
21
|
+
class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
    """Settings for embeddings served through the OpenAI client.

    A `model_name` that starts with "langdb/" makes `OpenAIEmbeddings`
    take its base URL and API key from `langdb_params` instead.
    """

    model_type: str = "openai"
    model_name: str = "text-embedding-3-small"
    # Empty by default; OpenAIEmbeddings then falls back to OPENAI_API_KEY.
    api_key: str = ""
    api_base: str | None = None
    organization: str = ""
    dims: int = 1536
    # Maximum number of tokens kept per text when truncating.
    context_length: int = 8192
    langdb_params: LangDBParams = LangDBParams()

    model_config = ConfigDict(env_prefix="OPENAI_")
|
32
|
+
|
33
|
+
class AzureOpenAIEmbeddingsConfig(EmbeddingModelsConfig):
    """Settings for embeddings served through the Azure OpenAI client."""

    model_type: str = "azure-openai"
    model_name: str = "text-embedding-3-small"
    # Both api_key and api_base must be non-empty; AzureOpenAIEmbeddings
    # raises ValueError otherwise.
    api_key: str = ""
    api_base: str = ""
    deployment_name: str | None = None
    # Default of 2024-06-01 follows
    # https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=python-new
    # Override with whichever supported version you need.
    api_version: str | None = "2024-06-01"
    # TODO: Add auth support for Azure OpenAI via AzureADTokenProvider
    azure_ad_token: str | None = None
    azure_ad_token_provider: AzureADTokenProvider | None = None
    dims: int = 1536
    # Maximum number of tokens kept per text when truncating.
    context_length: int = 8192

    model_config = ConfigDict(env_prefix="AZURE_OPENAI_")
|
49
|
+
|
50
|
+
class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
    """Settings for locally run sentence-transformers embeddings."""

    model_type: str = "sentence-transformer"
    model_name: str = "BAAI/bge-large-en-v1.5"
    # Initial value only: SentenceTransformerEmbeddings overwrites it with
    # the tokenizer's `model_max_length`.
    context_length: int = 512
    data_parallel: bool = False
    # Device (e.g. "cuda", "cpu") used when data parallel is disabled
    device: str | None = None
    # Devices used for the multi-process pool when data parallel is enabled
    devices: list[str] | None = None
|
59
|
+
|
60
|
+
|
61
|
+
class FastEmbedEmbeddingsConfig(EmbeddingModelsConfig):
    """Settings for qdrant/fastembed embeddings.

    Project page: https://github.com/qdrant/fastembed
    """

    model_type: str = "fastembed"
    model_name: str = "BAAI/bge-small-en-v1.5"
    batch_size: int = 256
    # cache_dir and threads are handed to fastembed's TextEmbedding as-is.
    cache_dir: str | None = None
    threads: int | None = None
    # Passed as `parallel=` when embedding texts.
    parallel: int | None = None
    # Extra keyword arguments forwarded to TextEmbedding(...).
    additional_kwargs: dict[str, Any] = {}
|
73
|
+
|
74
|
+
|
75
|
+
class LlamaCppServerEmbeddingsConfig(EmbeddingModelsConfig):
    """Settings for embeddings served by a llama.cpp HTTP server."""

    # Root URL of the server; LlamaCppServerEmbeddings raises ValueError
    # when this is left empty.
    api_base: str = ""
    # Token limit used by `truncate_string_to_context_size`.
    context_length: int = 2048
    # Number of tokens per chunk sent to the embedding endpoint.
    batch_size: int = 2048
|
79
|
+
|
80
|
+
|
81
|
+
class GeminiEmbeddingsConfig(EmbeddingModelsConfig):
    """Settings for Google Gemini embeddings."""

    model_type: str = "gemini"
    model_name: str = "models/text-embedding-004"
    # GeminiEmbeddings reads GEMINI_API_KEY from the environment.
    api_key: str = ""
    dims: int = 768
    # Number of texts sent per embed_content request.
    batch_size: int = 512
|
87
|
+
|
88
|
+
|
89
|
+
class EmbeddingFunctionCallable:
    """
    Callable wrapper that maps a list of texts to a list of embedding vectors.

    The concrete class of the wrapped model selects the backend: the
    OpenAI / Azure OpenAI client, sentence-transformers, fastembed, a
    llama.cpp server, or Gemini.

    Attributes:
        embed_model (EmbeddingModel): Model instance supplying the client,
            tokenizer and configuration used to embed texts.
        batch_size (int): Number of items handed to the backend per call.

    Methods:
        __call__(input: List[str]) -> Embeddings: Generate embeddings for
        a list of input texts.
    """

    def __init__(self, embed_model: EmbeddingModel, batch_size: int = 512):
        """
        Remember the model and batch size used by later calls.

        Args:
            embed_model (EmbeddingModel): The embedding model to delegate to.
            batch_size (int): Batch size
        """
        self.embed_model = embed_model
        self.batch_size = batch_size

    def __call__(self, input: List[str]) -> Embeddings:
        """
        Embed `input` using the backend that matches the wrapped model.

        Args:
            input (List[str]): A list of input texts to generate embeddings for.

        Returns:
            Embeddings: The embedding vectors produced by the backend, or an
                empty list when the model is of none of the known types.
        """
        model = self.embed_model
        if isinstance(model, (OpenAIEmbeddings, AzureOpenAIEmbeddings)):
            return self._embed_via_openai_client(input)
        if isinstance(model, SentenceTransformerEmbeddings):
            return self._embed_via_sentence_transformer(input)
        if isinstance(model, FastEmbedEmbeddings):
            return self._embed_via_fastembed(input)
        if isinstance(model, LlamaCppServerEmbeddings):
            return self._embed_via_llamacpp_server(input)
        if isinstance(model, GeminiEmbeddings):
            return model.generate_embeddings(input)
        return []

    def _embed_via_openai_client(self, texts: List[str]) -> Embeddings:
        """Truncate to the context length, then embed batch by batch."""
        model: Any = self.embed_model
        vectors = []
        # Truncation happens first so no item exceeds the context length
        shortened = model.truncate_texts(texts)
        for chunk in batched(shortened, self.batch_size):
            response = model.client.embeddings.create(
                input=chunk, model=model.config.model_name
            )
            vectors.extend([item.embedding for item in response.data])
        return vectors

    def _embed_via_sentence_transformer(self, texts: List[str]) -> Embeddings:
        """Encode with the multi-process pool, or batch by batch in-process."""
        model: Any = self.embed_model
        if model.config.data_parallel:
            return model.model.encode_multi_process(
                texts,
                model.pool,
                batch_size=self.batch_size,
            ).tolist()
        vectors = []
        for chunk in batched(texts, self.batch_size):
            encoded = model.model.encode(chunk, convert_to_numpy=True)
            vectors.extend(encoded.tolist())
        return vectors

    def _embed_via_fastembed(self, texts: List[str]) -> Embeddings:
        """Embed with fastembed and convert each result to a plain list."""
        model: Any = self.embed_model
        results = model.model.embed(
            texts, batch_size=self.batch_size, parallel=model.parallel
        )
        return [vector.tolist() for vector in results]

    def _embed_via_llamacpp_server(self, texts: List[str]) -> Embeddings:
        """Embed each text in chunks of at most `batch_size` tokens."""
        model: Any = self.embed_model
        vectors = []
        for text in texts:
            token_ids = model.tokenize_string(text)
            # One embedding is appended per token chunk, not per text
            for chunk in batched(token_ids, self.batch_size):
                chunk_text = model.detokenize_string(list(chunk))
                vectors.append(model.generate_embedding(chunk_text))
        return vectors
|
177
|
+
|
178
|
+
|
179
|
+
class OpenAIEmbeddings(EmbeddingModel):
    """Embedding model backed by the OpenAI client, optionally routed via LangDB."""

    def __init__(self, config: OpenAIEmbeddingsConfig = OpenAIEmbeddingsConfig()):
        """
        Build the OpenAI client and tokenizer described by `config`.

        Note that `config` is updated in place (model name prefixes are
        stripped; API base, key and organization are filled in).

        Args:
            config: Settings for the model. A `model_name` starting with
                "langdb/" routes requests through LangDB.

        Raises:
            ValueError: If no API key is configured or found in the environment.
        """
        super().__init__()
        self.config = config
        load_dotenv()

        # Check if using LangDB
        self.is_langdb = self.config.model_name.startswith("langdb/")

        if self.is_langdb:
            # Strip only the leading routing prefix, never later occurrences
            self.config.model_name = self.config.model_name.removeprefix("langdb/")
            self.config.api_base = self.config.langdb_params.base_url
            project_id = self.config.langdb_params.project_id
            if project_id:
                self.config.api_base += "/" + project_id + "/v1"
            self.config.api_key = self.config.langdb_params.api_key

        if not self.config.api_key:
            self.config.api_key = os.getenv("OPENAI_API_KEY", "")

        # Same precedence as api_key: an explicitly configured organization
        # wins, and the environment only fills the gap.
        if not self.config.organization:
            self.config.organization = os.getenv("OPENAI_ORGANIZATION", "")

        if self.config.api_key == "":
            if self.is_langdb:
                raise ValueError(
                    """
                    LANGDB_API_KEY must be set in .env or your environment
                    to use OpenAIEmbeddings via LangDB.
                    """
                )
            else:
                raise ValueError(
                    """
                    OPENAI_API_KEY must be set in .env or your environment
                    to use OpenAIEmbeddings.
                    """
                )

        self.client = OpenAI(
            base_url=self.config.api_base,
            api_key=self.config.api_key,
            organization=self.config.organization,
        )
        # An "openai/" provider prefix is not part of the model name itself
        self.config.model_name = self.config.model_name.removeprefix("openai/")
        self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)

    def truncate_texts(self, texts: List[str]) -> List[str] | List[List[int]]:
        """
        Truncate texts to the embedding model's context length.

        Args:
            texts: The texts to truncate.

        Returns:
            Token-id lists capped at `config.context_length`, or, for LangDB,
            the capped token lists decoded back into strings.

        TODO: Maybe we should show warning, and consider doing T5 summarization?
        """
        truncated_tokens = [
            self.tokenizer.encode(text, disallowed_special=())[
                : self.config.context_length
            ]
            for text in texts
        ]

        if self.is_langdb:
            # LangDB embedding endpt only works with strings, not tokens
            return [self.tokenizer.decode(tokens) for tokens in truncated_tokens]
        return truncated_tokens

    def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
        """Return a callable that embeds a list of texts with this model."""
        return EmbeddingFunctionCallable(self, self.config.batch_size)

    @property
    def embedding_dims(self) -> int:
        """Embedding dimensionality as configured in `config.dims`."""
        return self.config.dims
|
250
|
+
|
251
|
+
|
252
|
+
class AzureOpenAIEmbeddings(EmbeddingModel):
    """
    Embedding model backed by the Azure OpenAI client.
    """

    def __init__(
        self, config: AzureOpenAIEmbeddingsConfig = AzureOpenAIEmbeddingsConfig()
    ):
        """
        Build the Azure OpenAI client and tokenizer described by `config`.

        Args:
            config: Configuration for Azure OpenAI embeddings model.
        Raises:
            ValueError: If `api_key` or `api_base` is empty.
        """
        super().__init__()
        self.config = config
        load_dotenv()

        settings = self.config

        # Guard clauses: both credentials must be present before any client
        # is created.
        if settings.api_key == "":
            raise ValueError(
                """AZURE_OPENAI_API_KEY env variable must be set to use
                AzureOpenAIEmbeddings. Please set the AZURE_OPENAI_API_KEY value
                in your .env file."""
            )
        if settings.api_base == "":
            raise ValueError(
                """AZURE_OPENAI_API_BASE env variable must be set to use
                AzureOpenAIEmbeddings. Please set the AZURE_OPENAI_API_BASE value
                in your .env file."""
            )

        self.client = AzureOpenAI(
            api_key=settings.api_key,
            api_version=settings.api_version,
            azure_endpoint=settings.api_base,
            azure_deployment=settings.deployment_name,
        )
        self.tokenizer = tiktoken.encoding_for_model(settings.model_name)

    def truncate_texts(self, texts: List[str]) -> List[str] | List[List[int]]:
        """
        Tokenize each text and keep at most `config.context_length` tokens.

        Returns the capped token-id lists (not strings).
        TODO: Maybe we should show warning, and consider doing T5 summarization?
        """
        limit = self.config.context_length
        capped = []
        for text in texts:
            token_ids = self.tokenizer.encode(text, disallowed_special=())
            capped.append(token_ids[:limit])
        return capped

    def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
        """Get the embedding function for Azure OpenAI.

        Returns:
            Callable that generates embeddings for input texts.
        """
        return EmbeddingFunctionCallable(self, self.config.batch_size)

    @property
    def embedding_dims(self) -> int:
        """Embedding dimensionality as configured in `config.dims`."""
        return self.config.dims
|
316
|
+
|
317
|
+
|
318
|
+
STEC = SentenceTransformerEmbeddingsConfig
|
319
|
+
|
320
|
+
|
321
|
+
class SentenceTransformerEmbeddings(EmbeddingModel):
    """Embedding model that runs a sentence-transformers model locally."""

    def __init__(self, config: STEC = STEC()):
        """
        Load the model and tokenizer named by `config.model_name`.

        When `config.data_parallel` is set, a multi-process pool is started
        on `config.devices` and registered for shutdown at interpreter exit.
        `config.context_length` is overwritten with the tokenizer's
        `model_max_length`.

        Raises:
            ImportError: If the optional hf-embeddings dependencies are missing.
        """
        # These packages are an optional "extra", hence the lazy import
        try:
            from sentence_transformers import SentenceTransformer
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                """
                To use sentence_transformers embeddings,
                you must install langroid with the [hf-embeddings] extra, e.g.:
                pip install "langroid[hf-embeddings]"
                """
            )

        super().__init__()
        self.config = config
        model_name = self.config.model_name

        self.model = SentenceTransformer(model_name, device=self.config.device)

        if self.config.data_parallel:
            self.pool = self.model.start_multi_process_pool(
                self.config.devices  # type: ignore
            )

            def _shutdown_pool() -> None:
                # Looks up self.pool at exit time, like the original closure
                SentenceTransformer.stop_multi_process_pool(self.pool)

            atexit.register(_shutdown_pool)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.config.context_length = self.tokenizer.model_max_length

    def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
        """Return a callable that embeds a list of texts with this model."""
        return EmbeddingFunctionCallable(self, self.config.batch_size)

    @property
    def embedding_dims(self) -> int:
        """
        Embedding dimensionality as reported by the loaded model.

        Raises:
            ValueError: If the model does not report a dimension.
        """
        reported = self.model.get_sentence_embedding_dimension()
        if reported is not None:
            return reported  # type: ignore
        raise ValueError(
            f"Could not get embedding dimension for model {self.config.model_name}"
        )
|
365
|
+
|
366
|
+
|
367
|
+
class FastEmbedEmbeddings(EmbeddingModel):
    """Embedding model backed by fastembed's `TextEmbedding`."""

    def __init__(self, config: FastEmbedEmbeddingsConfig = FastEmbedEmbeddingsConfig()):
        """
        Create the fastembed model described by `config`.

        Raises:
            LangroidImportError: If the optional `fastembed` package is missing.
        """
        # fastembed is an optional extra, so it is imported lazily
        try:
            from fastembed import TextEmbedding
        except ImportError:
            raise LangroidImportError("fastembed", extra="fastembed")

        super().__init__()
        self.config = config
        self.batch_size = config.batch_size
        self.parallel = config.parallel

        model_kwargs = dict(
            model_name=self.config.model_name,
            cache_dir=self.config.cache_dir,
            threads=self.config.threads,
        )
        # A key repeated in additional_kwargs raises TypeError here, exactly
        # as passing it twice in a call would.
        self.model = TextEmbedding(**model_kwargs, **self.config.additional_kwargs)

    def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
        """Return a callable that embeds a list of texts with this model."""
        return EmbeddingFunctionCallable(self, self.config.batch_size)

    @cached_property
    def embedding_dims(self) -> int:
        """Dimensionality, measured once by embedding the probe string "text"."""
        probe_vectors = self.embedding_fn()(["text"])
        return len(probe_vectors[0])
|
393
|
+
|
394
|
+
|
395
|
+
LCSEC = LlamaCppServerEmbeddingsConfig
|
396
|
+
|
397
|
+
|
398
|
+
class LlamaCppServerEmbeddings(EmbeddingModel):
    """Embedding model that talks to a llama.cpp HTTP server.

    The `/tokenize`, `/detokenize` and `/embeddings` endpoint URLs are all
    derived from `config.api_base`.
    """

    def __init__(self, config: LCSEC = LCSEC()):
        """
        Derive the endpoint URLs from `config.api_base`.

        Raises:
            ValueError: If `config.api_base` is empty.
        """
        super().__init__()
        self.config = config

        if self.config.api_base == "":
            raise ValueError(
                """Api Base MUST be set for Llama Server Embeddings.
                """
            )

        # Tolerate a trailing slash so the endpoint URLs never contain "//"
        base_url = self.config.api_base.rstrip("/")
        self.tokenize_url = base_url + "/tokenize"
        self.detokenize_url = base_url + "/detokenize"
        self.embedding_url = base_url + "/embeddings"

    def _post_json(self, url: str, data: Dict[str, Any], failure_msg: str) -> Any:
        """POST `data` as JSON to `url` and return the decoded JSON body.

        Raises:
            requests.HTTPError: If the response status is not 200.
        """
        response = requests.post(url, json=data)
        if response.status_code != 200:
            raise requests.HTTPError(url, response.status_code, failure_msg)
        # requests.Response decodes its body with .json(); it has no
        # pydantic-style model_dump_json() method.
        return response.json()

    @staticmethod
    def _field(payload: Any, key: str) -> Any:
        """Return `payload[key]`, or None when payload is not a dict with it."""
        return payload.get(key) if isinstance(payload, dict) else None

    def tokenize_string(self, text: str) -> List[int]:
        """
        Tokenize `text` via the server's tokenize endpoint.

        Returns:
            The token ids; an empty list when the server returns no tokens.

        Raises:
            ValueError: If the response is not a list of numbers under "tokens".
            requests.HTTPError: If the response status is not 200.
        """
        payload = self._post_json(
            self.tokenize_url,
            {"content": text, "add_special": False, "with_pieces": False},
            "Failed to connect to tokenization provider",
        )
        tokens = self._field(payload, "tokens")
        # all() validates every element and also accepts an empty list,
        # where indexing tokens[0] would raise IndexError.
        if not (
            isinstance(tokens, list)
            and all(isinstance(token, (int, float)) for token in tokens)
        ):
            raise ValueError(
                """Tokenizer endpoint has not returned the correct format.
                Is the URL correct?
                """
            )
        return tokens

    def detokenize_string(self, tokens: List[int]) -> str:
        """
        Convert token ids back to text via the server's detokenize endpoint.

        Raises:
            ValueError: If the response has no string under "content".
            requests.HTTPError: If the response status is not 200.
        """
        payload = self._post_json(
            self.detokenize_url,
            {"tokens": tokens},
            "Failed to connect to detokenization provider",
        )
        text = self._field(payload, "content")
        if not isinstance(text, str):
            raise ValueError(
                """Detokenizer endpoint has not returned the correct format.
                Is the URL correct?
                """
            )
        return text

    def truncate_string_to_context_size(self, text: str) -> str:
        """Return `text` cut down to at most `config.context_length` tokens."""
        tokens = self.tokenize_string(text)
        tokens = tokens[: self.config.context_length]
        return self.detokenize_string(tokens)

    def generate_embedding(self, text: str) -> List[int | float]:
        """
        Embed a single text via the server's embeddings endpoint.

        Raises:
            ValueError: If the response is not a non-empty list of numbers
                under "embedding".
            requests.HTTPError: If the response status is not 200.
        """
        payload = self._post_json(
            self.embedding_url,
            {"content": text},
            "Failed to connect to embedding provider",
        )
        embeddings = self._field(payload, "embedding")
        # An empty vector is rejected explicitly instead of via IndexError
        if not (
            isinstance(embeddings, list)
            and embeddings
            and all(isinstance(value, (int, float)) for value in embeddings)
        ):
            raise ValueError(
                """Embedding endpoint has not returned the correct format.
                Is the URL correct?
                """
            )
        return embeddings

    def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
        """Return a callable that embeds a list of texts with this model."""
        return EmbeddingFunctionCallable(self, self.config.batch_size)

    @property
    def embedding_dims(self) -> int:
        """Embedding dimensionality as configured in `config.dims`."""
        return self.config.dims
|
487
|
+
|
488
|
+
|
489
|
+
class GeminiEmbeddings(EmbeddingModel):
    """Embedding model backed by the google-genai client."""

    def __init__(self, config: GeminiEmbeddingsConfig = GeminiEmbeddingsConfig()):
        """
        Create the Gemini client described by `config`.

        An API key set on `config` is used as-is; otherwise the key is read
        from the GEMINI_API_KEY environment variable.

        Raises:
            LangroidImportError: If the optional google-genai package is missing.
            ValueError: If no API key is configured or found in the environment.
        """
        try:
            from google import genai
        except ImportError as e:
            raise LangroidImportError(extra="google-genai", error=str(e))
        super().__init__()
        self.config = config
        load_dotenv()

        # Only fall back to the environment; an explicitly supplied key must
        # not be discarded.
        if not self.config.api_key:
            self.config.api_key = os.getenv("GEMINI_API_KEY", "")

        if self.config.api_key == "":
            raise ValueError(
                """
                GEMINI_API_KEY env variable must be set to use GeminiEmbeddings.
                """
            )
        self.client = genai.Client(api_key=self.config.api_key)

    def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
        """Return a callable that embeds a list of texts with this model."""
        return EmbeddingFunctionCallable(self, self.config.batch_size)

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generates embeddings for a list of input texts.

        Texts are sent in batches of `config.batch_size`.

        Raises:
            ValueError: If a response has no list-valued `embeddings` attribute.
        """
        all_embeddings: List[List[float]] = []

        for batch in batched(texts, self.config.batch_size):
            result = self.client.models.embed_content(  # type: ignore[attr-defined]
                model=self.config.model_name,
                contents=batch,  # type: ignore
            )

            if not hasattr(result, "embeddings") or not isinstance(
                result.embeddings, list
            ):
                raise ValueError(
                    "Unexpected format for embeddings: missing or incorrect type"
                )

            # Extract .values from ContentEmbedding objects
            all_embeddings.extend(
                [emb.values for emb in result.embeddings]  # type: ignore
            )

        return all_embeddings

    @property
    def embedding_dims(self) -> int:
        """Embedding dimensionality as configured in `config.dims`."""
        return self.config.dims
|
538
|
+
|
539
|
+
|
540
|
+
def embedding_model(embedding_fn_type: str = "openai") -> EmbeddingModel:
    """
    Look up the embedding model class registered under a type name.

    Args:
        embedding_fn_type: Type of embedding model to use. Options are:
            - "openai",
            - "azure-openai",
            - "fastembed",
            - "llamacppserver",
            - "gemini".
            Any other value selects the sentence-transformer model.
    Returns:
        EmbeddingModel: The corresponding embedding model class (the class
            itself, not an instance).
    """
    known_models = {
        "openai": OpenAIEmbeddings,
        "azure-openai": AzureOpenAIEmbeddings,
        "fastembed": FastEmbedEmbeddings,
        "llamacppserver": LlamaCppServerEmbeddings,
        "gemini": GeminiEmbeddings,
    }
    # Unrecognized names fall back to the sentence-transformer default
    return known_models.get(  # type: ignore
        embedding_fn_type, SentenceTransformerEmbeddings
    )
|
langroid/exceptions.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Callable
|
|
4
4
|
from dotenv import load_dotenv
|
5
5
|
from httpx import Timeout
|
6
6
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
7
|
+
from pydantic_settings import SettingsConfigDict
|
7
8
|
|
8
9
|
from langroid.language_models.openai_gpt import (
|
9
10
|
OpenAIGPT,
|
@@ -56,8 +57,7 @@ class AzureConfig(OpenAIGPTConfig):
|
|
56
57
|
# AZURE_OPENAI_API_VERSION=2023-05-15
|
57
58
|
# This is either done in the .env file, or via an explicit
|
58
59
|
# `export AZURE_OPENAI_API_VERSION=...`
|
59
|
-
|
60
|
-
env_prefix = "AZURE_OPENAI_"
|
60
|
+
model_config = SettingsConfigDict(env_prefix="AZURE_OPENAI_")
|
61
61
|
|
62
62
|
def __init__(self, **kwargs) -> None: # type: ignore
|
63
63
|
if "model_name" in kwargs and "chat_model" not in kwargs:
|