langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. langroid/__init__.py +70 -0
  2. langroid/agent/__init__.py +22 -0
  3. langroid/agent/base.py +120 -33
  4. langroid/agent/batch.py +134 -35
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +608 -0
  7. langroid/agent/chat_agent.py +164 -100
  8. langroid/agent/chat_document.py +19 -2
  9. langroid/agent/openai_assistant.py +20 -10
  10. langroid/agent/special/__init__.py +33 -10
  11. langroid/agent/special/doc_chat_agent.py +521 -108
  12. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  13. langroid/agent/special/lance_rag/__init__.py +9 -0
  14. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  15. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  16. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  17. langroid/agent/special/lance_tools.py +44 -0
  18. langroid/agent/special/neo4j/__init__.py +0 -0
  19. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  20. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  21. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  22. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  23. langroid/agent/special/relevance_extractor_agent.py +23 -7
  24. langroid/agent/special/retriever_agent.py +29 -174
  25. langroid/agent/special/sql/__init__.py +7 -0
  26. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  27. langroid/agent/special/sql/utils/__init__.py +11 -0
  28. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  29. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  30. langroid/agent/special/table_chat_agent.py +43 -9
  31. langroid/agent/task.py +423 -114
  32. langroid/agent/tool_message.py +67 -10
  33. langroid/agent/tools/__init__.py +8 -0
  34. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  35. langroid/agent/tools/google_search_tool.py +11 -0
  36. langroid/agent/tools/metaphor_search_tool.py +67 -0
  37. langroid/agent/tools/recipient_tool.py +6 -24
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/cachedb/__init__.py +6 -0
  40. langroid/embedding_models/__init__.py +24 -0
  41. langroid/embedding_models/base.py +9 -1
  42. langroid/embedding_models/models.py +117 -17
  43. langroid/embedding_models/protoc/embeddings.proto +19 -0
  44. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  45. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  46. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  47. langroid/embedding_models/remote_embeds.py +153 -0
  48. langroid/language_models/__init__.py +22 -0
  49. langroid/language_models/azure_openai.py +47 -4
  50. langroid/language_models/base.py +26 -10
  51. langroid/language_models/config.py +5 -0
  52. langroid/language_models/openai_gpt.py +407 -121
  53. langroid/language_models/prompt_formatter/__init__.py +9 -0
  54. langroid/language_models/prompt_formatter/base.py +4 -6
  55. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  56. langroid/language_models/utils.py +10 -9
  57. langroid/mytypes.py +10 -4
  58. langroid/parsing/__init__.py +33 -1
  59. langroid/parsing/document_parser.py +259 -63
  60. langroid/parsing/image_text.py +32 -0
  61. langroid/parsing/parse_json.py +143 -0
  62. langroid/parsing/parser.py +20 -7
  63. langroid/parsing/repo_loader.py +108 -46
  64. langroid/parsing/search.py +8 -0
  65. langroid/parsing/table_loader.py +44 -0
  66. langroid/parsing/url_loader.py +59 -13
  67. langroid/parsing/urls.py +18 -9
  68. langroid/parsing/utils.py +130 -9
  69. langroid/parsing/web_search.py +73 -0
  70. langroid/prompts/__init__.py +7 -0
  71. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  72. langroid/prompts/prompts_config.py +1 -1
  73. langroid/utils/__init__.py +10 -0
  74. langroid/utils/algorithms/__init__.py +3 -0
  75. langroid/utils/configuration.py +0 -1
  76. langroid/utils/constants.py +4 -0
  77. langroid/utils/logging.py +2 -5
  78. langroid/utils/output/__init__.py +15 -2
  79. langroid/utils/output/status.py +33 -0
  80. langroid/utils/pandas_utils.py +30 -0
  81. langroid/utils/pydantic_utils.py +446 -4
  82. langroid/utils/system.py +36 -1
  83. langroid/vector_store/__init__.py +34 -2
  84. langroid/vector_store/base.py +33 -2
  85. langroid/vector_store/chromadb.py +42 -13
  86. langroid/vector_store/lancedb.py +226 -60
  87. langroid/vector_store/meilisearch.py +7 -6
  88. langroid/vector_store/momento.py +3 -2
  89. langroid/vector_store/qdrantdb.py +82 -11
  90. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
  91. langroid-0.1.219.dist-info/RECORD +127 -0
  92. langroid/agent/special/recipient_validator_agent.py +0 -157
  93. langroid/parsing/json.py +0 -64
  94. langroid/utils/web/selenium_login.py +0 -36
  95. langroid-0.1.139.dist-info/RECORD +0 -103
  96. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
  97. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
@@ -6,6 +6,8 @@ an agent. The messages could represent, for example:
6
6
  - request to run a method of the agent
7
7
  """
8
8
 
9
+ import json
10
+ import textwrap
9
11
  from abc import ABC
10
12
  from random import choice
11
13
  from typing import Any, Dict, List, Type
@@ -14,16 +16,10 @@ from docstring_parser import parse
14
16
  from pydantic import BaseModel
15
17
 
16
18
  from langroid.language_models.base import LLMFunctionSpec
17
-
18
-
19
- def _recursive_purge_dict_key(d: Dict[str, Any], k: str) -> None:
20
- """Remove a key from a dictionary recursively"""
21
- if isinstance(d, dict):
22
- for key in list(d.keys()):
23
- if key == k and "type" in d.keys():
24
- del d[key]
25
- else:
26
- _recursive_purge_dict_key(d[key], k)
19
+ from langroid.utils.pydantic_utils import (
20
+ _recursive_purge_dict_key,
21
+ generate_simple_schema,
22
+ )
27
23
 
28
24
 
29
25
  class ToolMessage(ABC, BaseModel):
@@ -86,6 +82,9 @@ class ToolMessage(ABC, BaseModel):
86
82
  ex = choice(cls.examples())
87
83
  return ex.json_example()
88
84
 
85
+ def to_json(self) -> str:
86
+ return self.json(indent=4, exclude={"result", "purpose"})
87
+
89
88
  def json_example(self) -> str:
90
89
  return self.json(indent=4, exclude={"result", "purpose"})
91
90
 
@@ -107,6 +106,53 @@ class ToolMessage(ABC, BaseModel):
107
106
  properties = schema["properties"]
108
107
  return properties.get(f, {}).get("default", None)
109
108
 
109
+ @classmethod
110
+ def json_instructions(cls, tool: bool = False) -> str:
111
+ """
112
+ Default Instructions to the LLM showing how to use the tool/function-call.
113
+ Works for GPT4 but override this for weaker LLMs if needed.
114
+
115
+ Args:
116
+ tool: instructions for Langroid-native tool use? (e.g. for non-OpenAI LLM)
117
+ (or else it would be for OpenAI Function calls)
118
+ Returns:
119
+ str: instructions on how to use the message
120
+ """
121
+ # TODO: when we attempt to use a "simpler schema"
122
+ # (i.e. all nested fields explicit without definitions),
123
+ # we seem to get worse results, so we turn it off for now
124
+ param_dict = (
125
+ # cls.simple_schema() if tool else
126
+ cls.llm_function_schema(request=True).parameters
127
+ )
128
+ return textwrap.dedent(
129
+ f"""
130
+ TOOL: {cls.default_value("request")}
131
+ PURPOSE: {cls.default_value("purpose")}
132
+ JSON FORMAT: {
133
+ json.dumps(param_dict, indent=4)
134
+ }
135
+ {"EXAMPLE: " + cls.usage_example() if cls.examples() else ""}
136
+ """.lstrip()
137
+ )
138
+
139
+ @staticmethod
140
+ def json_group_instructions() -> str:
141
+ """Template for instructions for a group of tools.
142
+ Works with GPT4 but override this for weaker LLMs if needed.
143
+ """
144
+ return textwrap.dedent(
145
+ """
146
+ === ALL AVAILABLE TOOLS and THEIR JSON FORMAT INSTRUCTIONS ===
147
+ You have access to the following TOOLS to accomplish your task:
148
+
149
+ {json_instructions}
150
+
151
+ When one of the above TOOLs is applicable, you must express your
152
+ request as "TOOL:" followed by the request in the above JSON format.
153
+ """
154
+ )
155
+
110
156
  @classmethod
111
157
  def llm_function_schema(
112
158
  cls,
@@ -178,3 +224,14 @@ class ToolMessage(ABC, BaseModel):
178
224
  description=cls.default_value("purpose"),
179
225
  parameters=parameters,
180
226
  )
227
+
228
+ @classmethod
229
+ def simple_schema(cls) -> Dict[str, Any]:
230
+ """
231
+ Return a simplified schema for the message, with only the request and
232
+ required fields.
233
+ Returns:
234
+ Dict[str, Any]: simplified schema
235
+ """
236
+ schema = generate_simple_schema(cls, exclude=["result", "purpose"])
237
+ return schema
@@ -3,3 +3,11 @@ from .recipient_tool import AddRecipientTool, RecipientTool
3
3
 
4
4
  from . import google_search_tool
5
5
  from . import recipient_tool
6
+
7
+ __all__ = [
8
+ "GoogleSearchTool",
9
+ "AddRecipientTool",
10
+ "RecipientTool",
11
+ "google_search_tool",
12
+ "recipient_tool",
13
+ ]
@@ -0,0 +1,66 @@
1
+ """
2
+ A tool to trigger a Metaphor search for a given query,
3
+ (https://docs.exa.ai/reference/getting-started)
4
+ and return the top results with their titles, links, summaries.
5
+ Since the tool is stateless (i.e. does not need
6
+ access to agent state), it can be enabled for any agent, without having to define a
7
+ special method inside the agent: `agent.enable_message(MetaphorSearchTool)`
8
+
9
+ NOTE: To use this tool, you need to:
10
+
11
+ * set the METAPHOR_API_KEY environment variables in
12
+ your `.env` file, e.g. `METAPHOR_API_KEY=your_api_key_here`
13
+ (Note as of 28 Jan 2023, Metaphor renamed to Exa, so you can also use
14
+ `EXA_API_KEY=your_api_key_here`)
15
+
16
+ * install langroid with the `metaphor` extra, e.g.
17
+ `pip install langroid[metaphor]` or `poetry add langroid[metaphor]`
18
+ (it installs the `metaphor-python` package from pypi).
19
+
20
+ For more information, please refer to the official docs:
21
+ https://metaphor.systems/
22
+ """
23
+
24
+ from typing import List
25
+
26
+ from langroid.agent.tool_message import ToolMessage
27
+ from langroid.parsing.web_search import duckduckgo_search
28
+
29
+
30
+ class DuckduckgoSearchTool(ToolMessage):
31
+ request: str = "duckduckgo_search"
32
+ purpose: str = """
33
+ To search the web and return up to <num_results>
34
+ links relevant to the given <query>. When using this tool,
35
+ ONLY show the required JSON, DO NOT SAY ANYTHING ELSE.
36
+ Wait for the results of the web search, and then use them to
37
+ compose your response.
38
+ """
39
+ query: str
40
+ num_results: int
41
+
42
+ def handle(self) -> str:
43
+ """
44
+ Conducts a search using the metaphor API based on the provided query
45
+ and number of results by triggering a metaphor_search.
46
+
47
+ Returns:
48
+ str: A formatted string containing the titles, links, and
49
+ summaries of each search result, separated by two newlines.
50
+ """
51
+ search_results = duckduckgo_search(self.query, self.num_results)
52
+ # return Title, Link, Summary of each result, separated by two newlines
53
+ results_str = "\n\n".join(str(result) for result in search_results)
54
+ return f"""
55
+ BELOW ARE THE RESULTS FROM THE WEB SEARCH. USE THESE TO COMPOSE YOUR RESPONSE:
56
+ {results_str}
57
+ """
58
+
59
+ @classmethod
60
+ def examples(cls) -> List["ToolMessage"]:
61
+ return [
62
+ cls(
63
+ query="When was the Llama2 Large Language Model (LLM) released?",
64
+ num_results=3,
65
+ ),
66
+ ]
@@ -9,6 +9,8 @@ environment variables in your `.env` file, as explained in the
9
9
  [README](https://github.com/langroid/langroid#gear-installation-and-setup).
10
10
  """
11
11
 
12
+ from typing import List
13
+
12
14
  from langroid.agent.tool_message import ToolMessage
13
15
  from langroid.parsing.web_search import google_search
14
16
 
@@ -26,3 +28,12 @@ class GoogleSearchTool(ToolMessage):
26
28
  search_results = google_search(self.query, self.num_results)
27
29
  # return Title, Link, Summary of each result, separated by two newlines
28
30
  return "\n\n".join(str(result) for result in search_results)
31
+
32
+ @classmethod
33
+ def examples(cls) -> List["ToolMessage"]:
34
+ return [
35
+ cls(
36
+ query="When was the Llama2 Large Language Model (LLM) released?",
37
+ num_results=3,
38
+ ),
39
+ ]
@@ -0,0 +1,67 @@
1
+ """
2
+ A tool to trigger a Metaphor search for a given query,
3
+ (https://docs.exa.ai/reference/getting-started)
4
+ and return the top results with their titles, links, summaries.
5
+ Since the tool is stateless (i.e. does not need
6
+ access to agent state), it can be enabled for any agent, without having to define a
7
+ special method inside the agent: `agent.enable_message(MetaphorSearchTool)`
8
+
9
+ NOTE: To use this tool, you need to:
10
+
11
+ * set the METAPHOR_API_KEY environment variables in
12
+ your `.env` file, e.g. `METAPHOR_API_KEY=your_api_key_here`
13
+ (Note as of 28 Jan 2023, Metaphor renamed to Exa, so you can also use
14
+ `EXA_API_KEY=your_api_key_here`)
15
+
16
+ * install langroid with the `metaphor` extra, e.g.
17
+ `pip install langroid[metaphor]` or `poetry add langroid[metaphor]`
18
+ (it installs the `metaphor-python` package from pypi).
19
+
20
+ For more information, please refer to the official docs:
21
+ https://metaphor.systems/
22
+ """
23
+
24
+ from typing import List
25
+
26
+ from langroid.agent.tool_message import ToolMessage
27
+ from langroid.parsing.web_search import metaphor_search
28
+
29
+
30
+ class MetaphorSearchTool(ToolMessage):
31
+ request: str = "metaphor_search"
32
+ purpose: str = """
33
+ To search the web and return up to <num_results>
34
+ links relevant to the given <query>. When using this tool,
35
+ ONLY show the required JSON, DO NOT SAY ANYTHING ELSE.
36
+ Wait for the results of the web search, and then use them to
37
+ compose your response.
38
+ """
39
+ query: str
40
+ num_results: int
41
+
42
+ def handle(self) -> str:
43
+ """
44
+ Conducts a search using the metaphor API based on the provided query
45
+ and number of results by triggering a metaphor_search.
46
+
47
+ Returns:
48
+ str: A formatted string containing the titles, links, and
49
+ summaries of each search result, separated by two newlines.
50
+ """
51
+
52
+ search_results = metaphor_search(self.query, self.num_results)
53
+ # return Title, Link, Summary of each result, separated by two newlines
54
+ results_str = "\n\n".join(str(result) for result in search_results)
55
+ return f"""
56
+ BELOW ARE THE RESULTS FROM THE WEB SEARCH. USE THESE TO COMPOSE YOUR RESPONSE:
57
+ {results_str}
58
+ """
59
+
60
+ @classmethod
61
+ def examples(cls) -> List["ToolMessage"]:
62
+ return [
63
+ cls(
64
+ query="When was the Llama2 Large Language Model (LLM) released?",
65
+ num_results=3,
66
+ ),
67
+ ]
@@ -6,25 +6,8 @@ the method `_get_tool_list()`).
6
6
 
7
7
  See usage examples in `tests/main/test_multi_agent_complex.py` and
8
8
  `tests/main/test_recipient_tool.py`.
9
-
10
- Previously we were using RecipientValidatorAgent to enforce proper
11
- recipient specifiction, but the preferred method is to use the
12
- `RecipientTool` class. This has numerous advantages:
13
- - it uses the tool/function-call mechanism to specify a recipient in a JSON-structured
14
- string, which is more consistent with the rest of the system, and does not require
15
- inventing a new syntax like `TO:<recipient>` (which the RecipientValidatorAgent
16
- uses).
17
- - it removes the need for any special parsing of the message content, since we leverage
18
- the built-in JSON tool-matching in `Agent.handle_message()` and downstream code.
19
- - it does not require setting the `parent_responder` field in the `ChatDocument`
20
- metadata, which is somewhat hacky.
21
- - it appears to be less brittle than requiring the LLM to use TO:<recipient> syntax:
22
- The LLM almost never forgets to use the RecipientTool as instructed.
23
- - The RecipientTool class acts as a specification of the required syntax, and also
24
- contains mechanisms to enforce this syntax.
25
- - For a developer who needs to enforce recipient specification for an agent, they only
26
- need to do `agent.enable_message(RecipientTool)`, and the rest is taken care of.
27
9
  """
10
+
28
11
  from typing import List, Type
29
12
 
30
13
  from rich import print
@@ -68,17 +51,16 @@ class AddRecipientTool(ToolMessage):
68
51
  )
69
52
  if self.__class__.saved_content == "":
70
53
  recipient_request_name = RecipientTool.default_value("request")
71
- raise ValueError(
72
- f"""
54
+ content = f"""
73
55
  Recipient specified but content is empty!
74
56
  This could be because the `{self.request}` tool/function was used
75
57
  before using `{recipient_request_name}` tool/function.
76
58
  Resend the message using `{recipient_request_name}` tool/function.
77
59
  """
78
- )
79
- content = self.__class__.saved_content # use class-level attrib value
80
- # erase content since we just used it.
81
- self.__class__.saved_content = ""
60
+ else:
61
+ content = self.__class__.saved_content # use class-level attrib value
62
+ # erase content since we just used it.
63
+ self.__class__.saved_content = ""
82
64
  return ChatDocument(
83
65
  content=content,
84
66
  metadata=ChatDocMetaData(
@@ -0,0 +1,79 @@
1
+ """
2
+ A tool which returns a Search RAG response from the SciPhi API.
3
+ their titles, links, summaries. Since the tool is stateless (i.e. does not need
4
+ access to agent state), it can be enabled for any agent, without having to define a
5
+ special method inside the agent: `agent.enable_message(SciPhiSearchRAGTool)`
6
+
7
+ Example return output appears as follows below:
8
+
9
+ <-- Query -->
10
+ ```
11
+ Find 3 results on the internet about the LK-99 superconducting material.
12
+ ``
13
+
14
+ <-- Response (compressed for this example)-->
15
+ ```
16
+ [ result1 ]
17
+
18
+ [ result2 ]
19
+
20
+ [ result3 ]
21
+
22
+ ```
23
+
24
+ NOTE: Using this tool requires getting an API key from sciphi.ai.
25
+ Setup is as simple as shown below
26
+ # Get a free API key at https://www.sciphi.ai/account
27
+ # export SCIPHI_API_KEY=$MY_SCIPHI_API_KEY before running the agent
28
+ # OR add SCIPHI_API_KEY=$MY_SCIPHI_API_KEY to your .env file
29
+
30
+ This tool requires installing langroid with the `sciphi` extra, e.g.
31
+ `pip install langroid[sciphi]` or `poetry add langroid[sciphi]`
32
+ (it installs the `agent-search` package from pypi).
33
+
34
+ For more information, please refer to the official docs:
35
+ https://agent-search.readthedocs.io/en/latest/
36
+ """
37
+
38
+ from typing import List
39
+
40
+ try:
41
+ from agent_search import SciPhi
42
+ except ImportError:
43
+ raise ImportError(
44
+ "You are attempting to use the `agent-search` library;"
45
+ "To use it, please install langroid with the `sciphi` extra, e.g. "
46
+ "`pip install langroid[sciphi]` or `poetry add langroid[sciphi]` "
47
+ "(it installs the `agent-search` package from pypi)."
48
+ )
49
+
50
+ from langroid.agent.tool_message import ToolMessage
51
+
52
+
53
+ class SciPhiSearchRAGTool(ToolMessage):
54
+ request: str = "web_search_rag"
55
+ purpose: str = """
56
+ To search the web with provider <search_provider> and
57
+ return a response summary with llm model <llm_model> the given <query>.
58
+ """
59
+ query: str
60
+
61
+ def handle(self) -> str:
62
+ rag_response = SciPhi().get_search_rag_response(
63
+ query=self.query, search_provider="bing", llm_model="SciPhi/Sensei-7B-V1"
64
+ )
65
+ result = rag_response["response"]
66
+ result = (
67
+ f"### RAG Response:\n{result}\n\n"
68
+ + "### Related Queries:\n"
69
+ + "\n".join(rag_response["related_queries"])
70
+ )
71
+ return result # type: ignore
72
+
73
+ @classmethod
74
+ def examples(cls) -> List["ToolMessage"]:
75
+ return [
76
+ cls(
77
+ query="When was the Llama2 Large Language Model (LLM) released?",
78
+ ),
79
+ ]
@@ -1,3 +1,9 @@
1
1
  from . import base
2
2
  from . import momento_cachedb
3
3
  from . import redis_cachedb
4
+
5
+ __all__ = [
6
+ "base",
7
+ "momento_cachedb",
8
+ "redis_cachedb",
9
+ ]
@@ -1,6 +1,11 @@
1
1
  from . import base
2
2
  from . import models
3
+ from . import remote_embeds
3
4
 
5
+ from .base import (
6
+ EmbeddingModel,
7
+ EmbeddingModelsConfig,
8
+ )
4
9
  from .models import (
5
10
  OpenAIEmbeddings,
6
11
  OpenAIEmbeddingsConfig,
@@ -8,3 +13,22 @@ from .models import (
8
13
  SentenceTransformerEmbeddings,
9
14
  embedding_model,
10
15
  )
16
+ from .remote_embeds import (
17
+ RemoteEmbeddingsConfig,
18
+ RemoteEmbeddings,
19
+ )
20
+
21
+ __all__ = [
22
+ "base",
23
+ "models",
24
+ "remote_embeds",
25
+ "EmbeddingModel",
26
+ "EmbeddingModelsConfig",
27
+ "OpenAIEmbeddings",
28
+ "OpenAIEmbeddingsConfig",
29
+ "SentenceTransformerEmbeddingsConfig",
30
+ "SentenceTransformerEmbeddings",
31
+ "embedding_model",
32
+ "RemoteEmbeddingsConfig",
33
+ "RemoteEmbeddings",
34
+ ]
@@ -12,6 +12,8 @@ logging.getLogger("openai").setLevel(logging.ERROR)
12
12
  class EmbeddingModelsConfig(BaseSettings):
13
13
  model_type: str = "openai"
14
14
  dims: int = 0
15
+ context_length: int = 512
16
+ batch_size: int = 512
15
17
 
16
18
 
17
19
  class EmbeddingModel(ABC):
@@ -27,8 +29,14 @@ class EmbeddingModel(ABC):
27
29
  SentenceTransformerEmbeddings,
28
30
  SentenceTransformerEmbeddingsConfig,
29
31
  )
32
+ from langroid.embedding_models.remote_embeds import (
33
+ RemoteEmbeddings,
34
+ RemoteEmbeddingsConfig,
35
+ )
30
36
 
31
- if isinstance(config, OpenAIEmbeddingsConfig):
37
+ if isinstance(config, RemoteEmbeddingsConfig):
38
+ return RemoteEmbeddings(config)
39
+ elif isinstance(config, OpenAIEmbeddingsConfig):
32
40
  return OpenAIEmbeddings(config)
33
41
  elif isinstance(config, SentenceTransformerEmbeddingsConfig):
34
42
  return SentenceTransformerEmbeddings(config)
@@ -1,29 +1,93 @@
1
+ import atexit
1
2
  import os
2
- from typing import Callable, List
3
+ from typing import Callable, List, Optional
3
4
 
5
+ import tiktoken
4
6
  from dotenv import load_dotenv
5
7
  from openai import OpenAI
6
8
 
7
9
  from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
8
- from langroid.language_models.utils import retry_with_exponential_backoff
9
10
  from langroid.mytypes import Embeddings
11
+ from langroid.parsing.utils import batched
10
12
 
11
13
 
12
14
  class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
13
15
  model_type: str = "openai"
14
16
  model_name: str = "text-embedding-ada-002"
15
17
  api_key: str = ""
18
+ api_base: Optional[str] = None
16
19
  organization: str = ""
17
20
  dims: int = 1536
21
+ context_length: int = 8192
18
22
 
19
23
 
20
24
  class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
21
25
  model_type: str = "sentence-transformer"
22
26
  model_name: str = "BAAI/bge-large-en-v1.5"
27
+ context_length: int = 512
28
+ data_parallel: bool = False
29
+ # Select device (e.g. "cuda", "cpu") when data parallel is disabled
30
+ device: Optional[str] = None
31
+ # Select devices when data parallel is enabled
32
+ devices: Optional[list[str]] = None
33
+
34
+
35
+ class EmbeddingFunctionCallable:
36
+ """
37
+ A callable class designed to generate embeddings for a list of texts using
38
+ the OpenAI API, with automatic retries on failure.
39
+
40
+ Attributes:
41
+ model (OpenAIEmbeddings): An instance of OpenAIEmbeddings that provides
42
+ configuration and utilities for generating embeddings.
43
+
44
+ Methods:
45
+ __call__(input: List[str]) -> Embeddings: Generate embeddings for
46
+ a list of input texts.
47
+ """
48
+
49
+ def __init__(self, model: "OpenAIEmbeddings", batch_size: int = 512):
50
+ """
51
+ Initialize the EmbeddingFunctionCallable with a specific model.
52
+
53
+ Args:
54
+ model (OpenAIEmbeddings): An instance of OpenAIEmbeddings to use for
55
+ generating embeddings.
56
+ batch_size (int): Batch size
57
+ """
58
+ self.model = model
59
+ self.batch_size = batch_size
60
+
61
+ def __call__(self, input: List[str]) -> Embeddings:
62
+ """
63
+ Generate embeddings for a given list of input texts using the OpenAI API,
64
+ with retries on failure.
65
+
66
+ This method:
67
+ - Truncates each text in the input list to the model's maximum context length.
68
+ - Processes the texts in batches to generate embeddings efficiently.
69
+ - Automatically retries the embedding generation process with exponential
70
+ backoff in case of failures.
71
+
72
+ Args:
73
+ input (List[str]): A list of input texts to generate embeddings for.
74
+
75
+ Returns:
76
+ Embeddings: A list of embedding vectors corresponding to the input texts.
77
+ """
78
+ tokenized_texts = self.model.truncate_texts(input)
79
+ embeds = []
80
+ for batch in batched(tokenized_texts, self.batch_size):
81
+ result = self.model.client.embeddings.create(
82
+ input=batch, model=self.model.config.model_name
83
+ )
84
+ batch_embeds = [d.embedding for d in result.data]
85
+ embeds.extend(batch_embeds)
86
+ return embeds
23
87
 
24
88
 
25
89
  class OpenAIEmbeddings(EmbeddingModel):
26
- def __init__(self, config: OpenAIEmbeddingsConfig):
90
+ def __init__(self, config: OpenAIEmbeddingsConfig = OpenAIEmbeddingsConfig()):
27
91
  super().__init__()
28
92
  self.config = config
29
93
  load_dotenv()
@@ -36,28 +100,38 @@ class OpenAIEmbeddings(EmbeddingModel):
36
100
  in your .env file.
37
101
  """
38
102
  )
39
- self.client = OpenAI(api_key=self.config.api_key)
103
+ self.client = OpenAI(base_url=self.config.api_base, api_key=self.config.api_key)
104
+ self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
105
+
106
+ def truncate_texts(self, texts: List[str]) -> List[List[int]]:
107
+ """
108
+ Truncate texts to the embedding model's context length.
109
+ TODO: Maybe we should show warning, and consider doing T5 summarization?
110
+ """
111
+ return [
112
+ self.tokenizer.encode(text, disallowed_special=())[
113
+ : self.config.context_length
114
+ ]
115
+ for text in texts
116
+ ]
40
117
 
41
118
  def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
42
- @retry_with_exponential_backoff
43
- def fn(texts: List[str]) -> Embeddings:
44
- result = self.client.embeddings.create(
45
- input=texts, model=self.config.model_name
46
- )
47
- return [d.embedding for d in result.data]
48
-
49
- return fn
119
+ return EmbeddingFunctionCallable(self, self.config.batch_size)
50
120
 
51
121
  @property
52
122
  def embedding_dims(self) -> int:
53
123
  return self.config.dims
54
124
 
55
125
 
126
+ STEC = SentenceTransformerEmbeddingsConfig
127
+
128
+
56
129
  class SentenceTransformerEmbeddings(EmbeddingModel):
57
- def __init__(self, config: SentenceTransformerEmbeddingsConfig):
130
+ def __init__(self, config: STEC = STEC()):
58
131
  # this is an "extra" optional dependency, so we import it here
59
132
  try:
60
133
  from sentence_transformers import SentenceTransformer
134
+ from transformers import AutoTokenizer
61
135
  except ImportError:
62
136
  raise ImportError(
63
137
  """
@@ -69,13 +143,39 @@ class SentenceTransformerEmbeddings(EmbeddingModel):
69
143
 
70
144
  super().__init__()
71
145
  self.config = config
72
- self.model = SentenceTransformer(self.config.model_name)
146
+
147
+ self.model = SentenceTransformer(
148
+ self.config.model_name,
149
+ device=self.config.device,
150
+ )
151
+ if self.config.data_parallel:
152
+ self.pool = self.model.start_multi_process_pool(
153
+ self.config.devices # type: ignore
154
+ )
155
+ atexit.register(
156
+ lambda: SentenceTransformer.stop_multi_process_pool(self.pool)
157
+ )
158
+
159
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
160
+ self.config.context_length = self.tokenizer.model_max_length
73
161
 
74
162
  def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
75
163
  def fn(texts: List[str]) -> Embeddings:
76
- return self.model.encode( # type: ignore
77
- texts, convert_to_numpy=True
78
- ).tolist()
164
+ if self.config.data_parallel:
165
+ embeds: Embeddings = self.model.encode_multi_process(
166
+ texts,
167
+ self.pool,
168
+ batch_size=self.config.batch_size,
169
+ ).tolist()
170
+ else:
171
+ embeds = []
172
+ for batch in batched(texts, self.config.batch_size):
173
+ batch_embeds = self.model.encode(
174
+ batch, convert_to_numpy=True
175
+ ).tolist() # type: ignore
176
+ embeds.extend(batch_embeds)
177
+
178
+ return embeds
79
179
 
80
180
  return fn
81
181
 
@@ -0,0 +1,19 @@
1
+ syntax = "proto3";
2
+
3
+ service Embedding {
4
+ rpc Embed (EmbeddingRequest) returns (BatchEmbeds) {};
5
+ }
6
+
7
+ message EmbeddingRequest {
8
+ string model_name = 1;
9
+ int32 batch_size = 2;
10
+ repeated string strings = 3;
11
+ }
12
+
13
+ message BatchEmbeds {
14
+ repeated Embed embeds = 1;
15
+ }
16
+
17
+ message Embed {
18
+ repeated float embed = 1;
19
+ }