langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/cachedb/redis_cachedb.py

@@ -1,7 +1,8 @@
 import json
 import logging
 import os
-from
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Dict, List, TypeVar
 
 import fakeredis
 import redis
@@ -10,6 +11,7 @@ from pydantic import BaseModel
 
 from langroid.cachedb.base import CacheDB
 
+T = TypeVar("T", bound="RedisCache")
 logger = logging.getLogger(__name__)
 
 
@@ -33,7 +35,7 @@ class RedisCache(CacheDB):
         load_dotenv()
 
         if self.config.fake:
-            self.
+            self.pool = fakeredis.FakeStrictRedis()  # type: ignore
         else:
             redis_password = os.getenv("REDIS_PASSWORD")
             redis_host = os.getenv("REDIS_HOST")
@@ -43,21 +45,46 @@
                     """REDIS_PASSWORD, REDIS_HOST, REDIS_PORT not set in .env file,
                     using fake redis client"""
                 )
-                self.
+                self.pool = fakeredis.FakeStrictRedis()  # type: ignore
             else:
-                self.
+                self.pool = redis.ConnectionPool(  # type: ignore
                     host=redis_host,
                     port=redis_port,
                     password=redis_password,
+                    max_connections=50,
+                    socket_timeout=5,
+                    socket_keepalive=True,
+                    retry_on_timeout=True,
+                    health_check_interval=30,
                 )
 
+    @contextmanager  # type: ignore
+    def redis_client(self) -> AbstractContextManager[T]:  # type: ignore
+        """Cleanly open and close a redis client, avoids max clients exceeded error"""
+        if isinstance(self.pool, fakeredis.FakeStrictRedis):
+            yield self.pool
+        else:
+            client: T = redis.Redis(connection_pool=self.pool)
+            try:
+                yield client
+            finally:
+                client.close()
+
+    def close_all_connections(self) -> None:
+        with self.redis_client() as client:  # type: ignore
+            clients = client.client_list()
+            for c in clients:
+                client.client_kill(c["addr"])
+
     def clear(self) -> None:
         """Clear keys from current db."""
-        self.
+        with self.redis_client() as client:  # type: ignore
+            client.flushdb()
 
     def clear_all(self) -> None:
         """Clear all keys from all dbs."""
-        self.
+        with self.redis_client() as client:  # type: ignore
+            client.flushall()
 
     def store(self, key: str, value: Any) -> None:
         """
@@ -67,9 +94,14 @@
             key (str): The key under which to store the value.
             value (Any): The value to store.
         """
-        self.
-
-
+        with self.redis_client() as client:  # type: ignore
+            try:
+                client.set(key, json.dumps(value))
+            except redis.exceptions.ConnectionError:
+                logger.warning("Redis connection error, not storing key/value")
+                return None
+
+    def retrieve(self, key: str) -> Dict[str, Any] | str | None:
         """
         Retrieve the value associated with a key.
 
@@ -79,5 +111,40 @@
         Returns:
             dict: The value associated with the key.
         """
-
-
+        with self.redis_client() as client:  # type: ignore
+            try:
+                value = client.get(key)
+            except redis.exceptions.ConnectionError:
+                logger.warning("Redis connection error, returning None")
+                return None
+            return json.loads(value) if value else None
+
+    def delete_keys(self, keys: List[str]) -> None:
+        """
+        Delete the keys from the cache.
+
+        Args:
+            keys (List[str]): The keys to delete.
+        """
+        with self.redis_client() as client:  # type: ignore
+            try:
+                client.delete(*keys)
+            except redis.exceptions.ConnectionError:
+                logger.warning("Redis connection error, not deleting keys")
+                return None
+
+    def delete_keys_pattern(self, pattern: str) -> None:
+        """
+        Delete the keys matching the pattern from the cache.
+
+        Args:
+            prefix (str): The pattern to match.
+        """
+        with self.redis_client() as client:  # type: ignore
+            try:
+                keys = client.keys(pattern)
+                if len(keys) > 0:
+                    client.delete(*keys)
+            except redis.exceptions.ConnectionError:
+                logger.warning("Redis connection error, not deleting keys")
+                return None
langroid/embedding_models/__init__.py

@@ -0,0 +1,34 @@
+from . import base
+from . import models
+from . import remote_embeds
+
+from .base import (
+    EmbeddingModel,
+    EmbeddingModelsConfig,
+)
+from .models import (
+    OpenAIEmbeddings,
+    OpenAIEmbeddingsConfig,
+    SentenceTransformerEmbeddingsConfig,
+    SentenceTransformerEmbeddings,
+    embedding_model,
+)
+from .remote_embeds import (
+    RemoteEmbeddingsConfig,
+    RemoteEmbeddings,
+)
+
+__all__ = [
+    "base",
+    "models",
+    "remote_embeds",
+    "EmbeddingModel",
+    "EmbeddingModelsConfig",
+    "OpenAIEmbeddings",
+    "OpenAIEmbeddingsConfig",
+    "SentenceTransformerEmbeddingsConfig",
+    "SentenceTransformerEmbeddings",
+    "embedding_model",
+    "RemoteEmbeddingsConfig",
+    "RemoteEmbeddings",
+]
langroid/embedding_models/base.py

@@ -1,15 +1,19 @@
 import logging
 from abc import ABC, abstractmethod
 
-
+import numpy as np
 from pydantic import BaseSettings
 
+from langroid.mytypes import EmbeddingFunction
+
 logging.getLogger("openai").setLevel(logging.ERROR)
 
 
 class EmbeddingModelsConfig(BaseSettings):
     model_type: str = "openai"
     dims: int = 0
+    context_length: int = 512
+    batch_size: int = 512
 
 
 class EmbeddingModel(ABC):
@@ -25,8 +29,14 @@ class EmbeddingModel(ABC):
            SentenceTransformerEmbeddings,
            SentenceTransformerEmbeddingsConfig,
        )
+        from langroid.embedding_models.remote_embeds import (
+            RemoteEmbeddings,
+            RemoteEmbeddingsConfig,
+        )
 
-        if isinstance(config,
+        if isinstance(config, RemoteEmbeddingsConfig):
+            return RemoteEmbeddings(config)
+        elif isinstance(config, OpenAIEmbeddingsConfig):
             return OpenAIEmbeddings(config)
         elif isinstance(config, SentenceTransformerEmbeddingsConfig):
             return SentenceTransformerEmbeddings(config)
@@ -41,3 +51,12 @@
     @abstractmethod
     def embedding_dims(self) -> int:
         pass
+
+    def similarity(self, text1: str, text2: str) -> float:
+        """Compute cosine similarity between two texts."""
+        [emb1, emb2] = self.embedding_fn()([text1, text2])
+        return float(
+            np.array(emb1)
+            @ np.array(emb2)
+            / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
+        )
langroid/embedding_models/models.py

@@ -1,32 +1,98 @@
+import atexit
 import os
-from typing import Callable, List
+from typing import Callable, List, Optional
 
-import
+import tiktoken
 from dotenv import load_dotenv
+from openai import OpenAI
 
 from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
-from langroid.language_models.utils import retry_with_exponential_backoff
 from langroid.mytypes import Embeddings
+from langroid.parsing.utils import batched
 
 
 class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
     model_type: str = "openai"
     model_name: str = "text-embedding-ada-002"
     api_key: str = ""
+    api_base: Optional[str] = None
+    organization: str = ""
     dims: int = 1536
+    context_length: int = 8192
 
 
 class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
     model_type: str = "sentence-transformer"
     model_name: str = "BAAI/bge-large-en-v1.5"
+    context_length: int = 512
+    data_parallel: bool = False
+    # Select device (e.g. "cuda", "cpu") when data parallel is disabled
+    device: Optional[str] = None
+    # Select devices when data parallel is enabled
+    devices: Optional[list[str]] = None
+
+
+class EmbeddingFunctionCallable:
+    """
+    A callable class designed to generate embeddings for a list of texts using
+    the OpenAI API, with automatic retries on failure.
+
+    Attributes:
+        model (OpenAIEmbeddings): An instance of OpenAIEmbeddings that provides
+            configuration and utilities for generating embeddings.
+
+    Methods:
+        __call__(input: List[str]) -> Embeddings: Generate embeddings for
+            a list of input texts.
+    """
+
+    def __init__(self, model: "OpenAIEmbeddings", batch_size: int = 512):
+        """
+        Initialize the EmbeddingFunctionCallable with a specific model.
+
+        Args:
+            model (OpenAIEmbeddings): An instance of OpenAIEmbeddings to use for
+                generating embeddings.
+            batch_size (int): Batch size
+        """
+        self.model = model
+        self.batch_size = batch_size
+
+    def __call__(self, input: List[str]) -> Embeddings:
+        """
+        Generate embeddings for a given list of input texts using the OpenAI API,
+        with retries on failure.
+
+        This method:
+        - Truncates each text in the input list to the model's maximum context length.
+        - Processes the texts in batches to generate embeddings efficiently.
+        - Automatically retries the embedding generation process with exponential
+          backoff in case of failures.
+
+        Args:
+            input (List[str]): A list of input texts to generate embeddings for.
+
+        Returns:
+            Embeddings: A list of embedding vectors corresponding to the input texts.
+        """
+        tokenized_texts = self.model.truncate_texts(input)
+        embeds = []
+        for batch in batched(tokenized_texts, self.batch_size):
+            result = self.model.client.embeddings.create(
+                input=batch, model=self.model.config.model_name
+            )
+            batch_embeds = [d.embedding for d in result.data]
+            embeds.extend(batch_embeds)
+        return embeds
 
 
 class OpenAIEmbeddings(EmbeddingModel):
-    def __init__(self, config: OpenAIEmbeddingsConfig):
+    def __init__(self, config: OpenAIEmbeddingsConfig = OpenAIEmbeddingsConfig()):
         super().__init__()
         self.config = config
         load_dotenv()
         self.config.api_key = os.getenv("OPENAI_API_KEY", "")
+        self.config.organization = os.getenv("OPENAI_ORGANIZATION", "")
         if self.config.api_key == "":
             raise ValueError(
                 """OPENAI_API_KEY env variable must be set to use
@@ -34,28 +100,38 @@ class OpenAIEmbeddings(EmbeddingModel):
                 in your .env file.
                 """
             )
-
+        self.client = OpenAI(base_url=self.config.api_base, api_key=self.config.api_key)
+        self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
+
+    def truncate_texts(self, texts: List[str]) -> List[List[int]]:
+        """
+        Truncate texts to the embedding model's context length.
+        TODO: Maybe we should show warning, and consider doing T5 summarization?
+        """
+        return [
+            self.tokenizer.encode(text, disallowed_special=())[
+                : self.config.context_length
+            ]
+            for text in texts
+        ]
 
     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
-
-        def fn(texts: List[str]) -> Embeddings:
-            result = openai.Embedding.create(  # type: ignore
-                input=texts, model=self.config.model_name
-            )
-            return [d["embedding"] for d in result["data"]]
-
-        return fn
+        return EmbeddingFunctionCallable(self, self.config.batch_size)
 
     @property
     def embedding_dims(self) -> int:
         return self.config.dims
 
 
+STEC = SentenceTransformerEmbeddingsConfig
+
+
 class SentenceTransformerEmbeddings(EmbeddingModel):
-    def __init__(self, config:
+    def __init__(self, config: STEC = STEC()):
         # this is an "extra" optional dependency, so we import it here
         try:
             from sentence_transformers import SentenceTransformer
+            from transformers import AutoTokenizer
         except ImportError:
             raise ImportError(
                 """
@@ -67,13 +143,39 @@ class SentenceTransformerEmbeddings(EmbeddingModel):
 
         super().__init__()
         self.config = config
-
+
+        self.model = SentenceTransformer(
+            self.config.model_name,
+            device=self.config.device,
+        )
+        if self.config.data_parallel:
+            self.pool = self.model.start_multi_process_pool(
+                self.config.devices  # type: ignore
+            )
+            atexit.register(
+                lambda: SentenceTransformer.stop_multi_process_pool(self.pool)
+            )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self.config.context_length = self.tokenizer.model_max_length
 
     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
         def fn(texts: List[str]) -> Embeddings:
-
-
-
+            if self.config.data_parallel:
+                embeds: Embeddings = self.model.encode_multi_process(
+                    texts,
+                    self.pool,
+                    batch_size=self.config.batch_size,
+                ).tolist()
+            else:
+                embeds = []
+                for batch in batched(texts, self.config.batch_size):
+                    batch_embeds = self.model.encode(
+                        batch, convert_to_numpy=True
+                    ).tolist()  # type: ignore
+                    embeds.extend(batch_embeds)
+
+            return embeds
 
         return fn
 
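
Taken together, these changes route every OpenAI embedding call through `EmbeddingFunctionCallable`: each text is truncated to `context_length` tokens, then sent to the v1 `OpenAI` client in `batch_size` chunks. A usage sketch (not part of the diff); it requires `OPENAI_API_KEY` in the environment, and the batch size shown is illustrative:

```python
# Sketch of the reworked embedding path; all names appear in the diff above.
from langroid.embedding_models.models import OpenAIEmbeddings, OpenAIEmbeddingsConfig

model = OpenAIEmbeddings(OpenAIEmbeddingsConfig(batch_size=16))
embed_fn = model.embedding_fn()                 # an EmbeddingFunctionCallable
vectors = embed_fn(["hello world", "goodbye"])  # truncate -> batch -> API call
assert len(vectors) == 2
assert len(vectors[0]) == model.embedding_dims  # 1536 for text-embedding-ada-002
```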
langroid/embedding_models/protoc/embeddings.proto

@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+service Embedding {
+  rpc Embed (EmbeddingRequest) returns (BatchEmbeds) {};
+}
+
+message EmbeddingRequest {
+  string model_name = 1;
+  int32 batch_size = 2;
+  repeated string strings = 3;
+}
+
+message BatchEmbeds {
+  repeated Embed embeds = 1;
+}
+
+message Embed {
+  repeated float embed = 1;
+}
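
This service definition backs the new `remote_embeds.py` module in the file list above. A client-side sketch against the generated stub (shown at the end of this diff); the server address and model name are assumptions:

```python
# Hypothetical client for the Embedding service defined above.
import grpc

import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb2
import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:  # address is an assumption
    stub = embeddings_pb2_grpc.EmbeddingStub(channel)
    request = embeddings_pb2.EmbeddingRequest(
        model_name="BAAI/bge-large-en-v1.5",  # illustrative model name
        batch_size=32,
        strings=["hello", "world"],
    )
    reply = stub.Embed(request)  # returns a BatchEmbeds message
    vectors = [list(e.embed) for e in reply.embeds]
```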
langroid/embedding_models/protoc/embeddings_pb2.py

@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: embeddings.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x10\x65mbeddings.proto"K\n\x10\x45mbeddingRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\x05\x12\x0f\n\x07strings\x18\x03 \x03(\t"%\n\x0b\x42\x61tchEmbeds\x12\x16\n\x06\x65mbeds\x18\x01 \x03(\x0b\x32\x06.Embed"\x16\n\x05\x45mbed\x12\r\n\x05\x65mbed\x18\x01 \x03(\x02\x32\x37\n\tEmbedding\x12*\n\x05\x45mbed\x12\x11.EmbeddingRequest\x1a\x0c.BatchEmbeds"\x00\x62\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embeddings_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_EMBEDDINGREQUEST"]._serialized_start = 20
+    _globals["_EMBEDDINGREQUEST"]._serialized_end = 95
+    _globals["_BATCHEMBEDS"]._serialized_start = 97
+    _globals["_BATCHEMBEDS"]._serialized_end = 134
+    _globals["_EMBED"]._serialized_start = 136
+    _globals["_EMBED"]._serialized_end = 158
+    _globals["_EMBEDDING"]._serialized_start = 160
+    _globals["_EMBEDDING"]._serialized_end = 215
+# @@protoc_insertion_point(module_scope)
langroid/embedding_models/protoc/embeddings_pb2.pyi

@@ -0,0 +1,50 @@
+from typing import (
+    ClassVar as _ClassVar,
+)
+from typing import (
+    Iterable as _Iterable,
+)
+from typing import (
+    Mapping as _Mapping,
+)
+from typing import (
+    Optional as _Optional,
+)
+from typing import (
+    Union as _Union,
+)
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf.internal import containers as _containers
+
+DESCRIPTOR: _descriptor.FileDescriptor
+
+class EmbeddingRequest(_message.Message):
+    __slots__ = ("model_name", "batch_size", "strings")
+    MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+    BATCH_SIZE_FIELD_NUMBER: _ClassVar[int]
+    STRINGS_FIELD_NUMBER: _ClassVar[int]
+    model_name: str
+    batch_size: int
+    strings: _containers.RepeatedScalarFieldContainer[str]
+    def __init__(
+        self,
+        model_name: _Optional[str] = ...,
+        batch_size: _Optional[int] = ...,
+        strings: _Optional[_Iterable[str]] = ...,
+    ) -> None: ...
+
+class BatchEmbeds(_message.Message):
+    __slots__ = ("embeds",)
+    EMBEDS_FIELD_NUMBER: _ClassVar[int]
+    embeds: _containers.RepeatedCompositeFieldContainer[Embed]
+    def __init__(
+        self, embeds: _Optional[_Iterable[_Union[Embed, _Mapping]]] = ...
+    ) -> None: ...
+
+class Embed(_message.Message):
+    __slots__ = ("embed",)
+    EMBED_FIELD_NUMBER: _ClassVar[int]
+    embed: _containers.RepeatedScalarFieldContainer[float]
+    def __init__(self, embed: _Optional[_Iterable[float]] = ...) -> None: ...
langroid/embedding_models/protoc/embeddings_pb2_grpc.py

@@ -0,0 +1,79 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import langroid.embedding_models.protoc.embeddings_pb2 as embeddings__pb2
+
+
+class EmbeddingStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.Embed = channel.unary_unary(
+            "/Embedding/Embed",
+            request_serializer=embeddings__pb2.EmbeddingRequest.SerializeToString,
+            response_deserializer=embeddings__pb2.BatchEmbeds.FromString,
+        )
+
+
+class EmbeddingServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def Embed(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Method not implemented!")
+        raise NotImplementedError("Method not implemented!")
+
+
+def add_EmbeddingServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+        "Embed": grpc.unary_unary_rpc_method_handler(
+            servicer.Embed,
+            request_deserializer=embeddings__pb2.EmbeddingRequest.FromString,
+            response_serializer=embeddings__pb2.BatchEmbeds.SerializeToString,
+        ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+        "Embedding", rpc_method_handlers
+    )
+    server.add_generic_rpc_handlers((generic_handler,))
+
+
+# This class is part of an EXPERIMENTAL API.
+class Embedding(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def Embed(
+        request,
+        target,
+        options=(),
+        channel_credentials=None,
+        call_credentials=None,
+        insecure=False,
+        compression=None,
+        wait_for_ready=None,
+        timeout=None,
+        metadata=None,
+    ):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            "/Embedding/Embed",
+            embeddings__pb2.EmbeddingRequest.SerializeToString,
+            embeddings__pb2.BatchEmbeds.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+        )
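
For completeness, a minimal server-side sketch wiring an `EmbeddingServicer` subclass into a gRPC server. The servicer body is a dummy that only demonstrates the message shapes; the real implementation in `remote_embeds.py` presumably delegates to an actual embedding model:

```python
# Hypothetical server pairing with the generated classes above.
from concurrent import futures

import grpc

import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb2
import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_pb2_grpc


class DummyEmbeddingServicer(embeddings_pb2_grpc.EmbeddingServicer):
    def Embed(self, request, context):
        # Return one placeholder vector per input string, just to show the shapes.
        embeds = [embeddings_pb2.Embed(embed=[0.0, 0.0]) for _ in request.strings]
        return embeddings_pb2.BatchEmbeds(embeds=embeds)


server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
embeddings_pb2_grpc.add_EmbeddingServicer_to_server(DummyEmbeddingServicer(), server)
server.add_insecure_port("[::]:50051")  # port is an assumption
server.start()
server.wait_for_termination()
```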