langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -1,7 +1,8 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
- from typing import Any, Dict, Optional
4
+ from contextlib import AbstractContextManager, contextmanager
5
+ from typing import Any, Dict, List, TypeVar
5
6
 
6
7
  import fakeredis
7
8
  import redis
@@ -10,6 +11,7 @@ from pydantic import BaseModel
10
11
 
11
12
  from langroid.cachedb.base import CacheDB
12
13
 
14
+ T = TypeVar("T", bound="RedisCache")
13
15
  logger = logging.getLogger(__name__)
14
16
 
15
17
 
@@ -33,7 +35,7 @@ class RedisCache(CacheDB):
33
35
  load_dotenv()
34
36
 
35
37
  if self.config.fake:
36
- self.client = fakeredis.FakeStrictRedis() # type: ignore
38
+ self.pool = fakeredis.FakeStrictRedis() # type: ignore
37
39
  else:
38
40
  redis_password = os.getenv("REDIS_PASSWORD")
39
41
  redis_host = os.getenv("REDIS_HOST")
@@ -43,21 +45,46 @@ class RedisCache(CacheDB):
43
45
  """REDIS_PASSWORD, REDIS_HOST, REDIS_PORT not set in .env file,
44
46
  using fake redis client"""
45
47
  )
46
- self.client = fakeredis.FakeStrictRedis() # type: ignore
48
+ self.pool = fakeredis.FakeStrictRedis() # type: ignore
47
49
  else:
48
- self.client = redis.Redis( # type: ignore
50
+ self.pool = redis.ConnectionPool( # type: ignore
49
51
  host=redis_host,
50
52
  port=redis_port,
51
53
  password=redis_password,
54
+ max_connections=50,
55
+ socket_timeout=5,
56
+ socket_keepalive=True,
57
+ retry_on_timeout=True,
58
+ health_check_interval=30,
52
59
  )
53
60
 
61
+ @contextmanager # type: ignore
62
+ def redis_client(self) -> AbstractContextManager[T]: # type: ignore
63
+ """Cleanly open and close a redis client, avoids max clients exceeded error"""
64
+ if isinstance(self.pool, fakeredis.FakeStrictRedis):
65
+ yield self.pool
66
+ else:
67
+ client: T = redis.Redis(connection_pool=self.pool)
68
+ try:
69
+ yield client
70
+ finally:
71
+ client.close()
72
+
73
+ def close_all_connections(self) -> None:
74
+ with self.redis_client() as client: # type: ignore
75
+ clients = client.client_list()
76
+ for c in clients:
77
+ client.client_kill(c["addr"])
78
+
54
79
  def clear(self) -> None:
55
80
  """Clear keys from current db."""
56
- self.client.flushdb()
81
+ with self.redis_client() as client: # type: ignore
82
+ client.flushdb()
57
83
 
58
84
  def clear_all(self) -> None:
59
85
  """Clear all keys from all dbs."""
60
- self.client.flushall()
86
+ with self.redis_client() as client: # type: ignore
87
+ client.flushall()
61
88
 
62
89
  def store(self, key: str, value: Any) -> None:
63
90
  """
@@ -67,9 +94,14 @@ class RedisCache(CacheDB):
67
94
  key (str): The key under which to store the value.
68
95
  value (Any): The value to store.
69
96
  """
70
- self.client.set(key, json.dumps(value))
71
-
72
- def retrieve(self, key: str) -> Optional[Dict[str, Any]]:
97
+ with self.redis_client() as client: # type: ignore
98
+ try:
99
+ client.set(key, json.dumps(value))
100
+ except redis.exceptions.ConnectionError:
101
+ logger.warning("Redis connection error, not storing key/value")
102
+ return None
103
+
104
+ def retrieve(self, key: str) -> Dict[str, Any] | str | None:
73
105
  """
74
106
  Retrieve the value associated with a key.
75
107
 
@@ -79,5 +111,40 @@ class RedisCache(CacheDB):
79
111
  Returns:
80
112
  dict: The value associated with the key.
81
113
  """
82
- value = self.client.get(key)
83
- return json.loads(value) if value else None
114
+ with self.redis_client() as client: # type: ignore
115
+ try:
116
+ value = client.get(key)
117
+ except redis.exceptions.ConnectionError:
118
+ logger.warning("Redis connection error, returning None")
119
+ return None
120
+ return json.loads(value) if value else None
121
+
122
+ def delete_keys(self, keys: List[str]) -> None:
123
+ """
124
+ Delete the keys from the cache.
125
+
126
+ Args:
127
+ keys (List[str]): The keys to delete.
128
+ """
129
+ with self.redis_client() as client: # type: ignore
130
+ try:
131
+ client.delete(*keys)
132
+ except redis.exceptions.ConnectionError:
133
+ logger.warning("Redis connection error, not deleting keys")
134
+ return None
135
+
136
+ def delete_keys_pattern(self, pattern: str) -> None:
137
+ """
138
+ Delete the keys matching the pattern from the cache.
139
+
140
+ Args:
141
+ prefix (str): The pattern to match.
142
+ """
143
+ with self.redis_client() as client: # type: ignore
144
+ try:
145
+ keys = client.keys(pattern)
146
+ if len(keys) > 0:
147
+ client.delete(*keys)
148
+ except redis.exceptions.ConnectionError:
149
+ logger.warning("Redis connection error, not deleting keys")
150
+ return None
@@ -0,0 +1,34 @@
1
+ from . import base
2
+ from . import models
3
+ from . import remote_embeds
4
+
5
+ from .base import (
6
+ EmbeddingModel,
7
+ EmbeddingModelsConfig,
8
+ )
9
+ from .models import (
10
+ OpenAIEmbeddings,
11
+ OpenAIEmbeddingsConfig,
12
+ SentenceTransformerEmbeddingsConfig,
13
+ SentenceTransformerEmbeddings,
14
+ embedding_model,
15
+ )
16
+ from .remote_embeds import (
17
+ RemoteEmbeddingsConfig,
18
+ RemoteEmbeddings,
19
+ )
20
+
21
+ __all__ = [
22
+ "base",
23
+ "models",
24
+ "remote_embeds",
25
+ "EmbeddingModel",
26
+ "EmbeddingModelsConfig",
27
+ "OpenAIEmbeddings",
28
+ "OpenAIEmbeddingsConfig",
29
+ "SentenceTransformerEmbeddingsConfig",
30
+ "SentenceTransformerEmbeddings",
31
+ "embedding_model",
32
+ "RemoteEmbeddingsConfig",
33
+ "RemoteEmbeddings",
34
+ ]
@@ -1,15 +1,19 @@
1
1
  import logging
2
2
  from abc import ABC, abstractmethod
3
3
 
4
- from chromadb.api.types import EmbeddingFunction
4
+ import numpy as np
5
5
  from pydantic import BaseSettings
6
6
 
7
+ from langroid.mytypes import EmbeddingFunction
8
+
7
9
  logging.getLogger("openai").setLevel(logging.ERROR)
8
10
 
9
11
 
10
12
  class EmbeddingModelsConfig(BaseSettings):
11
13
  model_type: str = "openai"
12
14
  dims: int = 0
15
+ context_length: int = 512
16
+ batch_size: int = 512
13
17
 
14
18
 
15
19
  class EmbeddingModel(ABC):
@@ -25,8 +29,14 @@ class EmbeddingModel(ABC):
25
29
  SentenceTransformerEmbeddings,
26
30
  SentenceTransformerEmbeddingsConfig,
27
31
  )
32
+ from langroid.embedding_models.remote_embeds import (
33
+ RemoteEmbeddings,
34
+ RemoteEmbeddingsConfig,
35
+ )
28
36
 
29
- if isinstance(config, OpenAIEmbeddingsConfig):
37
+ if isinstance(config, RemoteEmbeddingsConfig):
38
+ return RemoteEmbeddings(config)
39
+ elif isinstance(config, OpenAIEmbeddingsConfig):
30
40
  return OpenAIEmbeddings(config)
31
41
  elif isinstance(config, SentenceTransformerEmbeddingsConfig):
32
42
  return SentenceTransformerEmbeddings(config)
@@ -41,3 +51,12 @@ class EmbeddingModel(ABC):
41
51
  @abstractmethod
42
52
  def embedding_dims(self) -> int:
43
53
  pass
54
+
55
+ def similarity(self, text1: str, text2: str) -> float:
56
+ """Compute cosine similarity between two texts."""
57
+ [emb1, emb2] = self.embedding_fn()([text1, text2])
58
+ return float(
59
+ np.array(emb1)
60
+ @ np.array(emb2)
61
+ / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
62
+ )
@@ -1,32 +1,98 @@
1
+ import atexit
1
2
  import os
2
- from typing import Callable, List
3
+ from typing import Callable, List, Optional
3
4
 
4
- import openai
5
+ import tiktoken
5
6
  from dotenv import load_dotenv
7
+ from openai import OpenAI
6
8
 
7
9
  from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
8
- from langroid.language_models.utils import retry_with_exponential_backoff
9
10
  from langroid.mytypes import Embeddings
11
+ from langroid.parsing.utils import batched
10
12
 
11
13
 
12
14
  class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
13
15
  model_type: str = "openai"
14
16
  model_name: str = "text-embedding-ada-002"
15
17
  api_key: str = ""
18
+ api_base: Optional[str] = None
19
+ organization: str = ""
16
20
  dims: int = 1536
21
+ context_length: int = 8192
17
22
 
18
23
 
19
24
  class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
20
25
  model_type: str = "sentence-transformer"
21
26
  model_name: str = "BAAI/bge-large-en-v1.5"
27
+ context_length: int = 512
28
+ data_parallel: bool = False
29
+ # Select device (e.g. "cuda", "cpu") when data parallel is disabled
30
+ device: Optional[str] = None
31
+ # Select devices when data parallel is enabled
32
+ devices: Optional[list[str]] = None
33
+
34
+
35
+ class EmbeddingFunctionCallable:
36
+ """
37
+ A callable class designed to generate embeddings for a list of texts using
38
+ the OpenAI API, with automatic retries on failure.
39
+
40
+ Attributes:
41
+ model (OpenAIEmbeddings): An instance of OpenAIEmbeddings that provides
42
+ configuration and utilities for generating embeddings.
43
+
44
+ Methods:
45
+ __call__(input: List[str]) -> Embeddings: Generate embeddings for
46
+ a list of input texts.
47
+ """
48
+
49
+ def __init__(self, model: "OpenAIEmbeddings", batch_size: int = 512):
50
+ """
51
+ Initialize the EmbeddingFunctionCallable with a specific model.
52
+
53
+ Args:
54
+ model (OpenAIEmbeddings): An instance of OpenAIEmbeddings to use for
55
+ generating embeddings.
56
+ batch_size (int): Batch size
57
+ """
58
+ self.model = model
59
+ self.batch_size = batch_size
60
+
61
+ def __call__(self, input: List[str]) -> Embeddings:
62
+ """
63
+ Generate embeddings for a given list of input texts using the OpenAI API,
64
+ with retries on failure.
65
+
66
+ This method:
67
+ - Truncates each text in the input list to the model's maximum context length.
68
+ - Processes the texts in batches to generate embeddings efficiently.
69
+ - Automatically retries the embedding generation process with exponential
70
+ backoff in case of failures.
71
+
72
+ Args:
73
+ input (List[str]): A list of input texts to generate embeddings for.
74
+
75
+ Returns:
76
+ Embeddings: A list of embedding vectors corresponding to the input texts.
77
+ """
78
+ tokenized_texts = self.model.truncate_texts(input)
79
+ embeds = []
80
+ for batch in batched(tokenized_texts, self.batch_size):
81
+ result = self.model.client.embeddings.create(
82
+ input=batch, model=self.model.config.model_name
83
+ )
84
+ batch_embeds = [d.embedding for d in result.data]
85
+ embeds.extend(batch_embeds)
86
+ return embeds
22
87
 
23
88
 
24
89
  class OpenAIEmbeddings(EmbeddingModel):
25
- def __init__(self, config: OpenAIEmbeddingsConfig):
90
+ def __init__(self, config: OpenAIEmbeddingsConfig = OpenAIEmbeddingsConfig()):
26
91
  super().__init__()
27
92
  self.config = config
28
93
  load_dotenv()
29
94
  self.config.api_key = os.getenv("OPENAI_API_KEY", "")
95
+ self.config.organization = os.getenv("OPENAI_ORGANIZATION", "")
30
96
  if self.config.api_key == "":
31
97
  raise ValueError(
32
98
  """OPENAI_API_KEY env variable must be set to use
@@ -34,28 +100,38 @@ class OpenAIEmbeddings(EmbeddingModel):
34
100
  in your .env file.
35
101
  """
36
102
  )
37
- openai.api_key = self.config.api_key
103
+ self.client = OpenAI(base_url=self.config.api_base, api_key=self.config.api_key)
104
+ self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
105
+
106
+ def truncate_texts(self, texts: List[str]) -> List[List[int]]:
107
+ """
108
+ Truncate texts to the embedding model's context length.
109
+ TODO: Maybe we should show warning, and consider doing T5 summarization?
110
+ """
111
+ return [
112
+ self.tokenizer.encode(text, disallowed_special=())[
113
+ : self.config.context_length
114
+ ]
115
+ for text in texts
116
+ ]
38
117
 
39
118
  def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
40
- @retry_with_exponential_backoff
41
- def fn(texts: List[str]) -> Embeddings:
42
- result = openai.Embedding.create( # type: ignore
43
- input=texts, model=self.config.model_name
44
- )
45
- return [d["embedding"] for d in result["data"]]
46
-
47
- return fn
119
+ return EmbeddingFunctionCallable(self, self.config.batch_size)
48
120
 
49
121
  @property
50
122
  def embedding_dims(self) -> int:
51
123
  return self.config.dims
52
124
 
53
125
 
126
+ STEC = SentenceTransformerEmbeddingsConfig
127
+
128
+
54
129
  class SentenceTransformerEmbeddings(EmbeddingModel):
55
- def __init__(self, config: SentenceTransformerEmbeddingsConfig):
130
+ def __init__(self, config: STEC = STEC()):
56
131
  # this is an "extra" optional dependency, so we import it here
57
132
  try:
58
133
  from sentence_transformers import SentenceTransformer
134
+ from transformers import AutoTokenizer
59
135
  except ImportError:
60
136
  raise ImportError(
61
137
  """
@@ -67,13 +143,39 @@ class SentenceTransformerEmbeddings(EmbeddingModel):
67
143
 
68
144
  super().__init__()
69
145
  self.config = config
70
- self.model = SentenceTransformer(self.config.model_name)
146
+
147
+ self.model = SentenceTransformer(
148
+ self.config.model_name,
149
+ device=self.config.device,
150
+ )
151
+ if self.config.data_parallel:
152
+ self.pool = self.model.start_multi_process_pool(
153
+ self.config.devices # type: ignore
154
+ )
155
+ atexit.register(
156
+ lambda: SentenceTransformer.stop_multi_process_pool(self.pool)
157
+ )
158
+
159
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
160
+ self.config.context_length = self.tokenizer.model_max_length
71
161
 
72
162
  def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
73
163
  def fn(texts: List[str]) -> Embeddings:
74
- return self.model.encode( # type: ignore
75
- texts, convert_to_numpy=True
76
- ).tolist()
164
+ if self.config.data_parallel:
165
+ embeds: Embeddings = self.model.encode_multi_process(
166
+ texts,
167
+ self.pool,
168
+ batch_size=self.config.batch_size,
169
+ ).tolist()
170
+ else:
171
+ embeds = []
172
+ for batch in batched(texts, self.config.batch_size):
173
+ batch_embeds = self.model.encode(
174
+ batch, convert_to_numpy=True
175
+ ).tolist() # type: ignore
176
+ embeds.extend(batch_embeds)
177
+
178
+ return embeds
77
179
 
78
180
  return fn
79
181
 
@@ -0,0 +1,19 @@
1
+ syntax = "proto3";
2
+
3
+ service Embedding {
4
+ rpc Embed (EmbeddingRequest) returns (BatchEmbeds) {};
5
+ }
6
+
7
+ message EmbeddingRequest {
8
+ string model_name = 1;
9
+ int32 batch_size = 2;
10
+ repeated string strings = 3;
11
+ }
12
+
13
+ message BatchEmbeds {
14
+ repeated Embed embeds = 1;
15
+ }
16
+
17
+ message Embed {
18
+ repeated float embed = 1;
19
+ }
@@ -0,0 +1,33 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: embeddings.proto
4
+ # Protobuf Python Version: 4.25.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+
11
+ # @@protoc_insertion_point(imports)
12
+
13
+ _sym_db = _symbol_database.Default()
14
+
15
+
16
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
17
+ b'\n\x10\x65mbeddings.proto"K\n\x10\x45mbeddingRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\x05\x12\x0f\n\x07strings\x18\x03 \x03(\t"%\n\x0b\x42\x61tchEmbeds\x12\x16\n\x06\x65mbeds\x18\x01 \x03(\x0b\x32\x06.Embed"\x16\n\x05\x45mbed\x12\r\n\x05\x65mbed\x18\x01 \x03(\x02\x32\x37\n\tEmbedding\x12*\n\x05\x45mbed\x12\x11.EmbeddingRequest\x1a\x0c.BatchEmbeds"\x00\x62\x06proto3'
18
+ )
19
+
20
+ _globals = globals()
21
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embeddings_pb2", _globals)
23
+ if _descriptor._USE_C_DESCRIPTORS == False:
24
+ DESCRIPTOR._options = None
25
+ _globals["_EMBEDDINGREQUEST"]._serialized_start = 20
26
+ _globals["_EMBEDDINGREQUEST"]._serialized_end = 95
27
+ _globals["_BATCHEMBEDS"]._serialized_start = 97
28
+ _globals["_BATCHEMBEDS"]._serialized_end = 134
29
+ _globals["_EMBED"]._serialized_start = 136
30
+ _globals["_EMBED"]._serialized_end = 158
31
+ _globals["_EMBEDDING"]._serialized_start = 160
32
+ _globals["_EMBEDDING"]._serialized_end = 215
33
+ # @@protoc_insertion_point(module_scope)
@@ -0,0 +1,50 @@
1
+ from typing import (
2
+ ClassVar as _ClassVar,
3
+ )
4
+ from typing import (
5
+ Iterable as _Iterable,
6
+ )
7
+ from typing import (
8
+ Mapping as _Mapping,
9
+ )
10
+ from typing import (
11
+ Optional as _Optional,
12
+ )
13
+ from typing import (
14
+ Union as _Union,
15
+ )
16
+
17
+ from google.protobuf import descriptor as _descriptor
18
+ from google.protobuf import message as _message
19
+ from google.protobuf.internal import containers as _containers
20
+
21
+ DESCRIPTOR: _descriptor.FileDescriptor
22
+
23
+ class EmbeddingRequest(_message.Message):
24
+ __slots__ = ("model_name", "batch_size", "strings")
25
+ MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
26
+ BATCH_SIZE_FIELD_NUMBER: _ClassVar[int]
27
+ STRINGS_FIELD_NUMBER: _ClassVar[int]
28
+ model_name: str
29
+ batch_size: int
30
+ strings: _containers.RepeatedScalarFieldContainer[str]
31
+ def __init__(
32
+ self,
33
+ model_name: _Optional[str] = ...,
34
+ batch_size: _Optional[int] = ...,
35
+ strings: _Optional[_Iterable[str]] = ...,
36
+ ) -> None: ...
37
+
38
+ class BatchEmbeds(_message.Message):
39
+ __slots__ = ("embeds",)
40
+ EMBEDS_FIELD_NUMBER: _ClassVar[int]
41
+ embeds: _containers.RepeatedCompositeFieldContainer[Embed]
42
+ def __init__(
43
+ self, embeds: _Optional[_Iterable[_Union[Embed, _Mapping]]] = ...
44
+ ) -> None: ...
45
+
46
+ class Embed(_message.Message):
47
+ __slots__ = ("embed",)
48
+ EMBED_FIELD_NUMBER: _ClassVar[int]
49
+ embed: _containers.RepeatedScalarFieldContainer[float]
50
+ def __init__(self, embed: _Optional[_Iterable[float]] = ...) -> None: ...
@@ -0,0 +1,79 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
5
+ import langroid.embedding_models.protoc.embeddings_pb2 as embeddings__pb2
6
+
7
+
8
+ class EmbeddingStub(object):
9
+ """Missing associated documentation comment in .proto file."""
10
+
11
+ def __init__(self, channel):
12
+ """Constructor.
13
+
14
+ Args:
15
+ channel: A grpc.Channel.
16
+ """
17
+ self.Embed = channel.unary_unary(
18
+ "/Embedding/Embed",
19
+ request_serializer=embeddings__pb2.EmbeddingRequest.SerializeToString,
20
+ response_deserializer=embeddings__pb2.BatchEmbeds.FromString,
21
+ )
22
+
23
+
24
+ class EmbeddingServicer(object):
25
+ """Missing associated documentation comment in .proto file."""
26
+
27
+ def Embed(self, request, context):
28
+ """Missing associated documentation comment in .proto file."""
29
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
30
+ context.set_details("Method not implemented!")
31
+ raise NotImplementedError("Method not implemented!")
32
+
33
+
34
+ def add_EmbeddingServicer_to_server(servicer, server):
35
+ rpc_method_handlers = {
36
+ "Embed": grpc.unary_unary_rpc_method_handler(
37
+ servicer.Embed,
38
+ request_deserializer=embeddings__pb2.EmbeddingRequest.FromString,
39
+ response_serializer=embeddings__pb2.BatchEmbeds.SerializeToString,
40
+ ),
41
+ }
42
+ generic_handler = grpc.method_handlers_generic_handler(
43
+ "Embedding", rpc_method_handlers
44
+ )
45
+ server.add_generic_rpc_handlers((generic_handler,))
46
+
47
+
48
+ # This class is part of an EXPERIMENTAL API.
49
+ class Embedding(object):
50
+ """Missing associated documentation comment in .proto file."""
51
+
52
+ @staticmethod
53
+ def Embed(
54
+ request,
55
+ target,
56
+ options=(),
57
+ channel_credentials=None,
58
+ call_credentials=None,
59
+ insecure=False,
60
+ compression=None,
61
+ wait_for_ready=None,
62
+ timeout=None,
63
+ metadata=None,
64
+ ):
65
+ return grpc.experimental.unary_unary(
66
+ request,
67
+ target,
68
+ "/Embedding/Embed",
69
+ embeddings__pb2.EmbeddingRequest.SerializeToString,
70
+ embeddings__pb2.BatchEmbeds.FromString,
71
+ options,
72
+ channel_credentials,
73
+ insecure,
74
+ call_credentials,
75
+ compression,
76
+ wait_for_ready,
77
+ timeout,
78
+ metadata,
79
+ )