langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (129)
  1. langroid/__init__.py +106 -0
  2. langroid/agent/__init__.py +41 -0
  3. langroid/agent/base.py +1983 -0
  4. langroid/agent/batch.py +398 -0
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +598 -0
  7. langroid/agent/chat_agent.py +1899 -0
  8. langroid/agent/chat_document.py +454 -0
  9. langroid/agent/openai_assistant.py +882 -0
  10. langroid/agent/special/__init__.py +59 -0
  11. langroid/agent/special/arangodb/__init__.py +0 -0
  12. langroid/agent/special/arangodb/arangodb_agent.py +656 -0
  13. langroid/agent/special/arangodb/system_messages.py +186 -0
  14. langroid/agent/special/arangodb/tools.py +107 -0
  15. langroid/agent/special/arangodb/utils.py +36 -0
  16. langroid/agent/special/doc_chat_agent.py +1466 -0
  17. langroid/agent/special/lance_doc_chat_agent.py +262 -0
  18. langroid/agent/special/lance_rag/__init__.py +9 -0
  19. langroid/agent/special/lance_rag/critic_agent.py +198 -0
  20. langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
  21. langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
  22. langroid/agent/special/lance_tools.py +61 -0
  23. langroid/agent/special/neo4j/__init__.py +0 -0
  24. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  25. langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
  26. langroid/agent/special/neo4j/system_messages.py +120 -0
  27. langroid/agent/special/neo4j/tools.py +32 -0
  28. langroid/agent/special/relevance_extractor_agent.py +127 -0
  29. langroid/agent/special/retriever_agent.py +56 -0
  30. langroid/agent/special/sql/__init__.py +17 -0
  31. langroid/agent/special/sql/sql_chat_agent.py +654 -0
  32. langroid/agent/special/sql/utils/__init__.py +21 -0
  33. langroid/agent/special/sql/utils/description_extractors.py +190 -0
  34. langroid/agent/special/sql/utils/populate_metadata.py +85 -0
  35. langroid/agent/special/sql/utils/system_message.py +35 -0
  36. langroid/agent/special/sql/utils/tools.py +64 -0
  37. langroid/agent/special/table_chat_agent.py +263 -0
  38. langroid/agent/task.py +2095 -0
  39. langroid/agent/tool_message.py +393 -0
  40. langroid/agent/tools/__init__.py +38 -0
  41. langroid/agent/tools/duckduckgo_search_tool.py +50 -0
  42. langroid/agent/tools/file_tools.py +234 -0
  43. langroid/agent/tools/google_search_tool.py +39 -0
  44. langroid/agent/tools/metaphor_search_tool.py +68 -0
  45. langroid/agent/tools/orchestration.py +303 -0
  46. langroid/agent/tools/recipient_tool.py +235 -0
  47. langroid/agent/tools/retrieval_tool.py +32 -0
  48. langroid/agent/tools/rewind_tool.py +137 -0
  49. langroid/agent/tools/segment_extract_tool.py +41 -0
  50. langroid/agent/xml_tool_message.py +382 -0
  51. langroid/cachedb/__init__.py +17 -0
  52. langroid/cachedb/base.py +58 -0
  53. langroid/cachedb/momento_cachedb.py +108 -0
  54. langroid/cachedb/redis_cachedb.py +153 -0
  55. langroid/embedding_models/__init__.py +39 -0
  56. langroid/embedding_models/base.py +74 -0
  57. langroid/embedding_models/models.py +461 -0
  58. langroid/embedding_models/protoc/__init__.py +0 -0
  59. langroid/embedding_models/protoc/embeddings.proto +19 -0
  60. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  61. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  62. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  63. langroid/embedding_models/remote_embeds.py +153 -0
  64. langroid/exceptions.py +71 -0
  65. langroid/language_models/__init__.py +53 -0
  66. langroid/language_models/azure_openai.py +153 -0
  67. langroid/language_models/base.py +678 -0
  68. langroid/language_models/config.py +18 -0
  69. langroid/language_models/mock_lm.py +124 -0
  70. langroid/language_models/openai_gpt.py +1964 -0
  71. langroid/language_models/prompt_formatter/__init__.py +16 -0
  72. langroid/language_models/prompt_formatter/base.py +40 -0
  73. langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
  74. langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
  75. langroid/language_models/utils.py +151 -0
  76. langroid/mytypes.py +84 -0
  77. langroid/parsing/__init__.py +52 -0
  78. langroid/parsing/agent_chats.py +38 -0
  79. langroid/parsing/code_parser.py +121 -0
  80. langroid/parsing/document_parser.py +718 -0
  81. langroid/parsing/para_sentence_split.py +62 -0
  82. langroid/parsing/parse_json.py +155 -0
  83. langroid/parsing/parser.py +313 -0
  84. langroid/parsing/repo_loader.py +790 -0
  85. langroid/parsing/routing.py +36 -0
  86. langroid/parsing/search.py +275 -0
  87. langroid/parsing/spider.py +102 -0
  88. langroid/parsing/table_loader.py +94 -0
  89. langroid/parsing/url_loader.py +111 -0
  90. langroid/parsing/urls.py +273 -0
  91. langroid/parsing/utils.py +373 -0
  92. langroid/parsing/web_search.py +156 -0
  93. langroid/prompts/__init__.py +9 -0
  94. langroid/prompts/dialog.py +17 -0
  95. langroid/prompts/prompts_config.py +5 -0
  96. langroid/prompts/templates.py +141 -0
  97. langroid/pydantic_v1/__init__.py +10 -0
  98. langroid/pydantic_v1/main.py +4 -0
  99. langroid/utils/__init__.py +19 -0
  100. langroid/utils/algorithms/__init__.py +3 -0
  101. langroid/utils/algorithms/graph.py +103 -0
  102. langroid/utils/configuration.py +98 -0
  103. langroid/utils/constants.py +30 -0
  104. langroid/utils/git_utils.py +252 -0
  105. langroid/utils/globals.py +49 -0
  106. langroid/utils/logging.py +135 -0
  107. langroid/utils/object_registry.py +66 -0
  108. langroid/utils/output/__init__.py +20 -0
  109. langroid/utils/output/citations.py +41 -0
  110. langroid/utils/output/printing.py +99 -0
  111. langroid/utils/output/status.py +40 -0
  112. langroid/utils/pandas_utils.py +30 -0
  113. langroid/utils/pydantic_utils.py +602 -0
  114. langroid/utils/system.py +286 -0
  115. langroid/utils/types.py +93 -0
  116. langroid/vector_store/__init__.py +50 -0
  117. langroid/vector_store/base.py +359 -0
  118. langroid/vector_store/chromadb.py +214 -0
  119. langroid/vector_store/lancedb.py +406 -0
  120. langroid/vector_store/meilisearch.py +299 -0
  121. langroid/vector_store/momento.py +278 -0
  122. langroid/vector_store/qdrantdb.py +468 -0
  123. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
  124. langroid-0.33.7.dist-info/RECORD +127 -0
  125. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
  126. langroid-0.33.6.dist-info/RECORD +0 -7
  127. langroid-0.33.6.dist-info/entry_points.txt +0 -4
  128. pyproject.toml +0 -356
  129. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
langroid/embedding_models/models.py
@@ -0,0 +1,461 @@
+ import atexit
+ import os
+ from functools import cached_property
+ from typing import Any, Callable, Dict, List, Optional
+
+ import requests
+ import tiktoken
+ from dotenv import load_dotenv
+ from openai import AzureOpenAI, OpenAI
+
+ from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
+ from langroid.exceptions import LangroidImportError
+ from langroid.mytypes import Embeddings
+ from langroid.parsing.utils import batched
+
+ AzureADTokenProvider = Callable[[], str]
+
+
+ class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
+     model_type: str = "openai"
+     model_name: str = "text-embedding-ada-002"
+     api_key: str = ""
+     api_base: Optional[str] = None
+     organization: str = ""
+     dims: int = 1536
+     context_length: int = 8192
+
+
+ class AzureOpenAIEmbeddingsConfig(EmbeddingModelsConfig):
+     model_type: str = "azure-openai"
+     model_name: str = "text-embedding-ada-002"
+     api_key: str = ""
+     api_base: str = ""
+     deployment_name: Optional[str] = None
+     # api_version defaulted to 2024-06-01 as per https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=python-new
+     # change this to required supported version
+     api_version: Optional[str] = "2024-06-01"
+     # TODO: Add auth support for Azure OpenAI via AzureADTokenProvider
+     azure_ad_token: Optional[str] = None
+     azure_ad_token_provider: Optional[AzureADTokenProvider] = None
+     dims: int = 1536
+     context_length: int = 8192
+
+     class Config:
+         # enable auto-loading of env vars with AZURE_OPENAI_ prefix
+         env_prefix = "AZURE_OPENAI_"
+
+
+ class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
+     model_type: str = "sentence-transformer"
+     model_name: str = "BAAI/bge-large-en-v1.5"
+     context_length: int = 512
+     data_parallel: bool = False
+     # Select device (e.g. "cuda", "cpu") when data parallel is disabled
+     device: Optional[str] = None
+     # Select devices when data parallel is enabled
+     devices: Optional[list[str]] = None
+
+
+ class FastEmbedEmbeddingsConfig(EmbeddingModelsConfig):
+     """Config for qdrant/fastembed embeddings,
+     see here: https://github.com/qdrant/fastembed
+     """
+
+     model_type: str = "fastembed"
+     model_name: str = "BAAI/bge-small-en-v1.5"
+     batch_size: int = 256
+     cache_dir: Optional[str] = None
+     threads: Optional[int] = None
+     parallel: Optional[int] = None
+     additional_kwargs: Dict[str, Any] = {}
+
+
+ class LlamaCppServerEmbeddingsConfig(EmbeddingModelsConfig):
+     api_base: str = ""
+     context_length: int = 2048
+     batch_size: int = 2048
+
+
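Note: these config classes extend EmbeddingModelsConfig (langroid/embedding_models/base.py), so any field can be overridden at construction. A minimal sketch; the values shown are illustrative, not defaults introduced by this release:

    from langroid.embedding_models.models import OpenAIEmbeddingsConfig

    cfg = OpenAIEmbeddingsConfig(
        model_name="text-embedding-ada-002",
        dims=1536,            # must match the chosen model's output dimension
        context_length=8192,  # longer inputs are truncated to this many tokens
    )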
+ class EmbeddingFunctionCallable:
+     """
+     A callable class designed to generate embeddings for a list of texts using
+     the OpenAI or Azure OpenAI API, with automatic retries on failure.
+
+     Attributes:
+         embed_model (EmbeddingModel): An instance of EmbeddingModel that provides
+             configuration and utilities for generating embeddings.
+
+     Methods:
+         __call__(input: List[str]) -> Embeddings: Generate embeddings for
+             a list of input texts.
+     """
+
+     def __init__(self, embed_model: EmbeddingModel, batch_size: int = 512):
+         """
+         Initialize the EmbeddingFunctionCallable with a specific model.
+
+         Args:
+             embed_model (EmbeddingModel): The embedding model instance
+                 (e.g. OpenAIEmbeddings or AzureOpenAIEmbeddings) to use for
+                 generating embeddings.
+             batch_size (int): Batch size
+         """
+         self.embed_model = embed_model
+         self.batch_size = batch_size
+
+     def __call__(self, input: List[str]) -> Embeddings:
+         """
+         Generate embeddings for a given list of input texts using the OpenAI API,
+         with retries on failure.
+
+         This method:
+         - Truncates each text in the input list to the model's maximum context length.
+         - Processes the texts in batches to generate embeddings efficiently.
+         - Automatically retries the embedding generation process with exponential
+           backoff in case of failures.
+
+         Args:
+             input (List[str]): A list of input texts to generate embeddings for.
+
+         Returns:
+             Embeddings: A list of embedding vectors corresponding to the input texts.
+         """
+         embeds = []
+         if isinstance(self.embed_model, (OpenAIEmbeddings, AzureOpenAIEmbeddings)):
+             tokenized_texts = self.embed_model.truncate_texts(input)
+
+             for batch in batched(tokenized_texts, self.batch_size):
+                 result = self.embed_model.client.embeddings.create(
+                     input=batch, model=self.embed_model.config.model_name
+                 )
+                 batch_embeds = [d.embedding for d in result.data]
+                 embeds.extend(batch_embeds)
+
+         elif isinstance(self.embed_model, SentenceTransformerEmbeddings):
+             if self.embed_model.config.data_parallel:
+                 embeds = self.embed_model.model.encode_multi_process(
+                     input,
+                     self.embed_model.pool,
+                     batch_size=self.batch_size,
+                 ).tolist()
+             else:
+                 for str_batch in batched(input, self.batch_size):
+                     batch_embeds = self.embed_model.model.encode(
+                         str_batch, convert_to_numpy=True
+                     ).tolist()  # type: ignore
+                     embeds.extend(batch_embeds)
+
+         elif isinstance(self.embed_model, FastEmbedEmbeddings):
+             embeddings = self.embed_model.model.embed(
+                 input, batch_size=self.batch_size, parallel=self.embed_model.parallel
+             )
+
+             embeds = [embedding.tolist() for embedding in embeddings]
+         elif isinstance(self.embed_model, LlamaCppServerEmbeddings):
+             for input_string in input:
+                 tokenized_text = self.embed_model.tokenize_string(input_string)
+                 for token_batch in batched(tokenized_text, self.batch_size):
+                     gen_embedding = self.embed_model.generate_embedding(
+                         self.embed_model.detokenize_string(list(token_batch))
+                     )
+                     embeds.append(gen_embedding)
+         return embeds
+
+
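Note: each model class below returns an EmbeddingFunctionCallable from its embedding_fn(), so the same callable interface works regardless of backend. A minimal sketch, assuming OPENAI_API_KEY is set in the environment (batch_size is assumed to come from the EmbeddingModelsConfig base, since it is not declared in OpenAIEmbeddingsConfig above):

    from langroid.embedding_models.models import OpenAIEmbeddings

    embed_fn = OpenAIEmbeddings().embedding_fn()  # an EmbeddingFunctionCallable
    vecs = embed_fn(["hello", "world"])           # list of 1536-dim vectors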
+ class OpenAIEmbeddings(EmbeddingModel):
+     def __init__(self, config: OpenAIEmbeddingsConfig = OpenAIEmbeddingsConfig()):
+         super().__init__()
+         self.config = config
+         load_dotenv()
+         self.config.api_key = os.getenv("OPENAI_API_KEY", "")
+         self.config.organization = os.getenv("OPENAI_ORGANIZATION", "")
+         if self.config.api_key == "":
+             raise ValueError(
+                 """OPENAI_API_KEY env variable must be set to use
+                 OpenAIEmbeddings. Please set the OPENAI_API_KEY value
+                 in your .env file.
+                 """
+             )
+         self.client = OpenAI(base_url=self.config.api_base, api_key=self.config.api_key)
+         self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
+
+     def truncate_texts(self, texts: List[str]) -> List[List[int]]:
+         """
+         Truncate texts to the embedding model's context length.
+         TODO: Maybe we should show warning, and consider doing T5 summarization?
+         """
+         return [
+             self.tokenizer.encode(text, disallowed_special=())[
+                 : self.config.context_length
+             ]
+             for text in texts
+         ]
+
+     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
+         return EmbeddingFunctionCallable(self, self.config.batch_size)
+
+     @property
+     def embedding_dims(self) -> int:
+         return self.config.dims
+
+
+ class AzureOpenAIEmbeddings(EmbeddingModel):
+     """
+     Azure OpenAI embeddings model implementation.
+     """
+
+     def __init__(
+         self, config: AzureOpenAIEmbeddingsConfig = AzureOpenAIEmbeddingsConfig()
+     ):
+         """
+         Initializes Azure OpenAI embeddings model.
+
+         Args:
+             config: Configuration for Azure OpenAI embeddings model.
+         Raises:
+             ValueError: If required Azure config values are not set.
+         """
+         super().__init__()
+         self.config = config
+         load_dotenv()
+
+         if self.config.api_key == "":
+             raise ValueError(
+                 """AZURE_OPENAI_API_KEY env variable must be set to use
+                 AzureOpenAIEmbeddings. Please set the AZURE_OPENAI_API_KEY value
+                 in your .env file."""
+             )
+
+         if self.config.api_base == "":
+             raise ValueError(
+                 """AZURE_OPENAI_API_BASE env variable must be set to use
+                 AzureOpenAIEmbeddings. Please set the AZURE_OPENAI_API_BASE value
+                 in your .env file."""
+             )
+         self.client = AzureOpenAI(
+             api_key=self.config.api_key,
+             api_version=self.config.api_version,
+             azure_endpoint=self.config.api_base,
+             azure_deployment=self.config.deployment_name,
+         )
+         self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
+
+     def truncate_texts(self, texts: List[str]) -> List[List[int]]:
+         """
+         Truncate texts to the embedding model's context length.
+         TODO: Maybe we should show warning, and consider doing T5 summarization?
+         """
+         return [
+             self.tokenizer.encode(text, disallowed_special=())[
+                 : self.config.context_length
+             ]
+             for text in texts
+         ]
+
+     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
+         """Get the embedding function for Azure OpenAI.
+
+         Returns:
+             Callable that generates embeddings for input texts.
+         """
+         return EmbeddingFunctionCallable(self, self.config.batch_size)
+
+     @property
+     def embedding_dims(self) -> int:
+         return self.config.dims
+
+
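Note: both classes read credentials from the environment (via load_dotenv()). The Azure config's inner Config.env_prefix suggests its fields can also be auto-populated from AZURE_OPENAI_-prefixed variables, assuming the EmbeddingModelsConfig base supports pydantic-style env loading. A sketch with placeholder values:

    import os
    from langroid.embedding_models.models import (
        AzureOpenAIEmbeddings,
        AzureOpenAIEmbeddingsConfig,
    )

    os.environ["AZURE_OPENAI_API_KEY"] = "..."  # placeholder
    os.environ["AZURE_OPENAI_API_BASE"] = "https://example.openai.azure.com"
    model = AzureOpenAIEmbeddings(AzureOpenAIEmbeddingsConfig())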
+ STEC = SentenceTransformerEmbeddingsConfig
+
+
+ class SentenceTransformerEmbeddings(EmbeddingModel):
+     def __init__(self, config: STEC = STEC()):
+         # this is an "extra" optional dependency, so we import it here
+         try:
+             from sentence_transformers import SentenceTransformer
+             from transformers import AutoTokenizer
+         except ImportError:
+             raise ImportError(
+                 """
+                 To use sentence_transformers embeddings,
+                 you must install langroid with the [hf-embeddings] extra, e.g.:
+                 pip install "langroid[hf-embeddings]"
+                 """
+             )
+
+         super().__init__()
+         self.config = config
+
+         self.model = SentenceTransformer(
+             self.config.model_name,
+             device=self.config.device,
+         )
+         if self.config.data_parallel:
+             self.pool = self.model.start_multi_process_pool(
+                 self.config.devices  # type: ignore
+             )
+             atexit.register(
+                 lambda: SentenceTransformer.stop_multi_process_pool(self.pool)
+             )
+
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+         self.config.context_length = self.tokenizer.model_max_length
+
+     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
+         return EmbeddingFunctionCallable(self, self.config.batch_size)
+
+     @property
+     def embedding_dims(self) -> int:
+         dims = self.model.get_sentence_embedding_dimension()
+         if dims is None:
+             raise ValueError(
+                 f"Could not get embedding dimension for model {self.config.model_name}"
+             )
+         return dims  # type: ignore
+
+
+ class FastEmbedEmbeddings(EmbeddingModel):
+     def __init__(self, config: FastEmbedEmbeddingsConfig = FastEmbedEmbeddingsConfig()):
+         try:
+             from fastembed import TextEmbedding
+         except ImportError:
+             raise LangroidImportError("fastembed", extra="fastembed")
+
+         super().__init__()
+         self.config = config
+         self.batch_size = config.batch_size
+         self.parallel = config.parallel
+
+         self.model = TextEmbedding(
+             model_name=self.config.model_name,
+             cache_dir=self.config.cache_dir,
+             threads=self.config.threads,
+             **self.config.additional_kwargs,
+         )
+
+     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
+         return EmbeddingFunctionCallable(self, self.config.batch_size)
+
+     @cached_property
+     def embedding_dims(self) -> int:
+         embed_func = self.embedding_fn()
+         return len(embed_func(["text"])[0])
+
+
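Note: FastEmbedEmbeddings runs locally via the fastembed package (installed with the "fastembed" extra) and probes its own output dimension lazily by embedding a dummy string, as the cached_property above shows. A short sketch:

    from langroid.embedding_models.models import (
        FastEmbedEmbeddings,
        FastEmbedEmbeddingsConfig,
    )

    model = FastEmbedEmbeddings(FastEmbedEmbeddingsConfig(batch_size=64))
    print(model.embedding_dims)  # computed once by embedding ["text"]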
+ LCSEC = LlamaCppServerEmbeddingsConfig
+
+
+ class LlamaCppServerEmbeddings(EmbeddingModel):
+     def __init__(self, config: LCSEC = LCSEC()):
+         super().__init__()
+         self.config = config
+
+         if self.config.api_base == "":
+             raise ValueError(
+                 """api_base MUST be set for LlamaCppServerEmbeddings.
+                 """
+             )
+
+         self.tokenize_url = self.config.api_base + "/tokenize"
+         self.detokenize_url = self.config.api_base + "/detokenize"
+         self.embedding_url = self.config.api_base + "/embeddings"
+
+     def tokenize_string(self, text: str) -> List[int]:
+         data = {"content": text, "add_special": False, "with_pieces": False}
+         response = requests.post(self.tokenize_url, json=data)
+
+         if response.status_code == 200:
+             tokens = response.json()["tokens"]
+             if not (isinstance(tokens, list) and isinstance(tokens[0], (int, float))):
+                 # not all(isinstance(token, (int, float)) for token in tokens):
+                 raise ValueError(
+                     """Tokenizer endpoint has not returned the correct format.
+                     Is the URL correct?
+                     """
+                 )
+             return tokens
+         else:
+             raise requests.HTTPError(
+                 self.tokenize_url,
+                 response.status_code,
+                 "Failed to connect to tokenization provider",
+             )
+
+     def detokenize_string(self, tokens: List[int]) -> str:
+         data = {"tokens": tokens}
+         response = requests.post(self.detokenize_url, json=data)
+
+         if response.status_code == 200:
+             text = response.json()["content"]
+             if not isinstance(text, str):
+                 raise ValueError(
+                     """Detokenizer endpoint has not returned the correct format.
+                     Is the URL correct?
+                     """
+                 )
+             return text
+         else:
+             raise requests.HTTPError(
+                 self.detokenize_url,
+                 response.status_code,
+                 "Failed to connect to detokenization provider",
+             )
+
+     def truncate_string_to_context_size(self, text: str) -> str:
+         tokens = self.tokenize_string(text)
+         tokens = tokens[: self.config.context_length]
+         return self.detokenize_string(tokens)
+
+     def generate_embedding(self, text: str) -> List[int | float]:
+         data = {"content": text}
+         response = requests.post(self.embedding_url, json=data)
+
+         if response.status_code == 200:
+             embeddings = response.json()["embedding"]
+             if not (
+                 isinstance(embeddings, list) and isinstance(embeddings[0], (int, float))
+             ):
+                 raise ValueError(
+                     """Embedding endpoint has not returned the correct format.
+                     Is the URL correct?
+                     """
+                 )
+             return embeddings
+         else:
+             raise requests.HTTPError(
+                 self.embedding_url,
+                 response.status_code,
+                 "Failed to connect to embedding provider",
+             )
+
+     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
+         return EmbeddingFunctionCallable(self, self.config.batch_size)
+
+     @property
+     def embedding_dims(self) -> int:
+         return self.config.dims
+
+
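Note: this client assumes a llama.cpp-style server exposing /tokenize, /detokenize, and /embeddings endpoints under api_base. A sketch with an illustrative local URL:

    from langroid.embedding_models.models import (
        LlamaCppServerEmbeddings,
        LlamaCppServerEmbeddingsConfig,
    )

    cfg = LlamaCppServerEmbeddingsConfig(api_base="http://localhost:8080")
    model = LlamaCppServerEmbeddings(cfg)
    tokens = model.tokenize_string("hello world")  # POST /tokenize
    vec = model.generate_embedding("hello world")  # POST /embeddings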
+ def embedding_model(embedding_fn_type: str = "openai") -> EmbeddingModel:
+     """
+     Args:
+         embedding_fn_type: Type of embedding model to use. Options are:
+             - "openai",
+             - "azure-openai",
+             - "fastembed",
+             - "llamacppserver", or
+             - "sentencetransformer" (the default fallback).
+             (others may be added in the future)
+     Returns:
+         EmbeddingModel: The corresponding embedding model class.
+     """
+     if embedding_fn_type == "openai":
+         return OpenAIEmbeddings  # type: ignore
+     elif embedding_fn_type == "azure-openai":
+         return AzureOpenAIEmbeddings  # type: ignore
+     elif embedding_fn_type == "fastembed":
+         return FastEmbedEmbeddings  # type: ignore
+     elif embedding_fn_type == "llamacppserver":
+         return LlamaCppServerEmbeddings  # type: ignore
+     else:  # default sentence transformer
+         return SentenceTransformerEmbeddings  # type: ignore
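Note: despite the EmbeddingModel return annotation, embedding_model() returns the class itself, not an instance (hence the type: ignore markers), so callers instantiate it with the matching config:

    from langroid.embedding_models.models import (
        OpenAIEmbeddingsConfig,
        embedding_model,
    )

    cls = embedding_model("openai")        # -> OpenAIEmbeddings (the class)
    model = cls(OpenAIEmbeddingsConfig())  # instantiate with its config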
langroid/embedding_models/protoc/__init__.py (file without changes)
langroid/embedding_models/protoc/embeddings.proto
@@ -0,0 +1,19 @@
+ syntax = "proto3";
+
+ service Embedding {
+   rpc Embed (EmbeddingRequest) returns (BatchEmbeds) {};
+ }
+
+ message EmbeddingRequest {
+   string model_name = 1;
+   int32 batch_size = 2;
+   repeated string strings = 3;
+ }
+
+ message BatchEmbeds {
+   repeated Embed embeds = 1;
+ }
+
+ message Embed {
+   repeated float embed = 1;
+ }
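Note: the generated *_pb2* modules below can be reproduced from this .proto with grpcio-tools. A sketch, assuming it is run from inside langroid/embedding_models/protoc/ (the --pyi_out flag requires a reasonably recent protoc/grpcio-tools):

    from grpc_tools import protoc

    protoc.main([
        "grpc_tools.protoc",
        "-I.",
        "--python_out=.",       # -> embeddings_pb2.py
        "--pyi_out=.",          # -> embeddings_pb2.pyi
        "--grpc_python_out=.",  # -> embeddings_pb2_grpc.py
        "embeddings.proto",
    ])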
langroid/embedding_models/protoc/embeddings_pb2.py
@@ -0,0 +1,33 @@
+ # -*- coding: utf-8 -*-
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
+ # source: embeddings.proto
+ # Protobuf Python Version: 4.25.1
+ """Generated protocol buffer code."""
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import descriptor_pool as _descriptor_pool
+ from google.protobuf import symbol_database as _symbol_database
+ from google.protobuf.internal import builder as _builder
+
+ # @@protoc_insertion_point(imports)
+
+ _sym_db = _symbol_database.Default()
+
+
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+     b'\n\x10\x65mbeddings.proto"K\n\x10\x45mbeddingRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\x05\x12\x0f\n\x07strings\x18\x03 \x03(\t"%\n\x0b\x42\x61tchEmbeds\x12\x16\n\x06\x65mbeds\x18\x01 \x03(\x0b\x32\x06.Embed"\x16\n\x05\x45mbed\x12\r\n\x05\x65mbed\x18\x01 \x03(\x02\x32\x37\n\tEmbedding\x12*\n\x05\x45mbed\x12\x11.EmbeddingRequest\x1a\x0c.BatchEmbeds"\x00\x62\x06proto3'
+ )
+
+ _globals = globals()
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embeddings_pb2", _globals)
+ if _descriptor._USE_C_DESCRIPTORS == False:
+     DESCRIPTOR._options = None
+     _globals["_EMBEDDINGREQUEST"]._serialized_start = 20
+     _globals["_EMBEDDINGREQUEST"]._serialized_end = 95
+     _globals["_BATCHEMBEDS"]._serialized_start = 97
+     _globals["_BATCHEMBEDS"]._serialized_end = 134
+     _globals["_EMBED"]._serialized_start = 136
+     _globals["_EMBED"]._serialized_end = 158
+     _globals["_EMBEDDING"]._serialized_start = 160
+     _globals["_EMBEDDING"]._serialized_end = 215
+ # @@protoc_insertion_point(module_scope)
langroid/embedding_models/protoc/embeddings_pb2.pyi
@@ -0,0 +1,50 @@
+ from typing import (
+     ClassVar as _ClassVar,
+ )
+ from typing import (
+     Iterable as _Iterable,
+ )
+ from typing import (
+     Mapping as _Mapping,
+ )
+ from typing import (
+     Optional as _Optional,
+ )
+ from typing import (
+     Union as _Union,
+ )
+
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import message as _message
+ from google.protobuf.internal import containers as _containers
+
+ DESCRIPTOR: _descriptor.FileDescriptor
+
+ class EmbeddingRequest(_message.Message):
+     __slots__ = ("model_name", "batch_size", "strings")
+     MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+     BATCH_SIZE_FIELD_NUMBER: _ClassVar[int]
+     STRINGS_FIELD_NUMBER: _ClassVar[int]
+     model_name: str
+     batch_size: int
+     strings: _containers.RepeatedScalarFieldContainer[str]
+     def __init__(
+         self,
+         model_name: _Optional[str] = ...,
+         batch_size: _Optional[int] = ...,
+         strings: _Optional[_Iterable[str]] = ...,
+     ) -> None: ...
+
+ class BatchEmbeds(_message.Message):
+     __slots__ = ("embeds",)
+     EMBEDS_FIELD_NUMBER: _ClassVar[int]
+     embeds: _containers.RepeatedCompositeFieldContainer[Embed]
+     def __init__(
+         self, embeds: _Optional[_Iterable[_Union[Embed, _Mapping]]] = ...
+     ) -> None: ...
+
+ class Embed(_message.Message):
+     __slots__ = ("embed",)
+     EMBED_FIELD_NUMBER: _ClassVar[int]
+     embed: _containers.RepeatedScalarFieldContainer[float]
+     def __init__(self, embed: _Optional[_Iterable[float]] = ...) -> None: ...
langroid/embedding_models/protoc/embeddings_pb2_grpc.py
@@ -0,0 +1,79 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+
+ import langroid.embedding_models.protoc.embeddings_pb2 as embeddings__pb2
+
+
+ class EmbeddingStub(object):
+     """Missing associated documentation comment in .proto file."""
+
+     def __init__(self, channel):
+         """Constructor.
+
+         Args:
+             channel: A grpc.Channel.
+         """
+         self.Embed = channel.unary_unary(
+             "/Embedding/Embed",
+             request_serializer=embeddings__pb2.EmbeddingRequest.SerializeToString,
+             response_deserializer=embeddings__pb2.BatchEmbeds.FromString,
+         )
+
+
+ class EmbeddingServicer(object):
+     """Missing associated documentation comment in .proto file."""
+
+     def Embed(self, request, context):
+         """Missing associated documentation comment in .proto file."""
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details("Method not implemented!")
+         raise NotImplementedError("Method not implemented!")
+
+
+ def add_EmbeddingServicer_to_server(servicer, server):
+     rpc_method_handlers = {
+         "Embed": grpc.unary_unary_rpc_method_handler(
+             servicer.Embed,
+             request_deserializer=embeddings__pb2.EmbeddingRequest.FromString,
+             response_serializer=embeddings__pb2.BatchEmbeds.SerializeToString,
+         ),
+     }
+     generic_handler = grpc.method_handlers_generic_handler(
+         "Embedding", rpc_method_handlers
+     )
+     server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+ class Embedding(object):
+     """Missing associated documentation comment in .proto file."""
+
+     @staticmethod
+     def Embed(
+         request,
+         target,
+         options=(),
+         channel_credentials=None,
+         call_credentials=None,
+         insecure=False,
+         compression=None,
+         wait_for_ready=None,
+         timeout=None,
+         metadata=None,
+     ):
+         return grpc.experimental.unary_unary(
+             request,
+             target,
+             "/Embedding/Embed",
+             embeddings__pb2.EmbeddingRequest.SerializeToString,
+             embeddings__pb2.BatchEmbeds.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+         )
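Note: a minimal client sketch using the generated stub against a server registered via add_EmbeddingServicer_to_server; the address and model name are illustrative (langroid/embedding_models/remote_embeds.py in this release presumably builds on these stubs):

    import grpc

    import langroid.embedding_models.protoc.embeddings_pb2 as pb2
    import langroid.embedding_models.protoc.embeddings_pb2_grpc as pb2_grpc

    with grpc.insecure_channel("localhost:50051") as channel:
        stub = pb2_grpc.EmbeddingStub(channel)
        request = pb2.EmbeddingRequest(
            model_name="BAAI/bge-large-en-v1.5",
            batch_size=32,
            strings=["hello", "world"],
        )
        reply = stub.Embed(request)  # returns a BatchEmbeds message
        vectors = [list(e.embed) for e in reply.embeds]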