langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +70 -0
- langroid/agent/__init__.py +22 -0
- langroid/agent/base.py +120 -33
- langroid/agent/batch.py +134 -35
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +164 -100
- langroid/agent/chat_document.py +19 -2
- langroid/agent/openai_assistant.py +20 -10
- langroid/agent/special/__init__.py +33 -10
- langroid/agent/special/doc_chat_agent.py +521 -108
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +23 -7
- langroid/agent/special/retriever_agent.py +29 -174
- langroid/agent/special/sql/__init__.py +7 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +11 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +423 -114
- langroid/agent/tool_message.py +67 -10
- langroid/agent/tools/__init__.py +8 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +6 -24
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/cachedb/__init__.py +6 -0
- langroid/embedding_models/__init__.py +24 -0
- langroid/embedding_models/base.py +9 -1
- langroid/embedding_models/models.py +117 -17
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +22 -0
- langroid/language_models/azure_openai.py +47 -4
- langroid/language_models/base.py +26 -10
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_gpt.py +407 -121
- langroid/language_models/prompt_formatter/__init__.py +9 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +10 -9
- langroid/mytypes.py +10 -4
- langroid/parsing/__init__.py +33 -1
- langroid/parsing/document_parser.py +259 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +20 -7
- langroid/parsing/repo_loader.py +108 -46
- langroid/parsing/search.py +8 -0
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -13
- langroid/parsing/urls.py +18 -9
- langroid/parsing/utils.py +130 -9
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +7 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +10 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/configuration.py +0 -1
- langroid/utils/constants.py +4 -0
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +15 -2
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +446 -4
- langroid/utils/system.py +36 -1
- langroid/vector_store/__init__.py +34 -2
- langroid/vector_store/base.py +33 -2
- langroid/vector_store/chromadb.py +42 -13
- langroid/vector_store/lancedb.py +226 -60
- langroid/vector_store/meilisearch.py +7 -6
- langroid/vector_store/momento.py +3 -2
- langroid/vector_store/qdrantdb.py +82 -11
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
- langroid-0.1.219.dist-info/RECORD +127 -0
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.139.dist-info/RECORD +0 -103
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
langroid/agent/tool_message.py
CHANGED
@@ -6,6 +6,8 @@ an agent. The messages could represent, for example:
 - request to run a method of the agent
 """
 
+import json
+import textwrap
 from abc import ABC
 from random import choice
 from typing import Any, Dict, List, Type
@@ -14,16 +16,10 @@ from docstring_parser import parse
 from pydantic import BaseModel
 
 from langroid.language_models.base import LLMFunctionSpec
-
-
-def _recursive_purge_dict_key(d: Dict[str, Any], k: str) -> None:
-    """Purge a key from a dictionary recursively"""
-    if isinstance(d, dict):
-        for key in list(d.keys()):
-            if key == k and "type" in d.keys():
-                del d[key]
-            else:
-                _recursive_purge_dict_key(d[key], k)
+from langroid.utils.pydantic_utils import (
+    _recursive_purge_dict_key,
+    generate_simple_schema,
+)
 
 
 class ToolMessage(ABC, BaseModel):
@@ -86,6 +82,9 @@ class ToolMessage(ABC, BaseModel):
         ex = choice(cls.examples())
         return ex.json_example()
 
+    def to_json(self) -> str:
+        return self.json(indent=4, exclude={"result", "purpose"})
+
     def json_example(self) -> str:
         return self.json(indent=4, exclude={"result", "purpose"})
 
@@ -107,6 +106,53 @@ class ToolMessage(ABC, BaseModel):
         properties = schema["properties"]
         return properties.get(f, {}).get("default", None)
 
+    @classmethod
+    def json_instructions(cls, tool: bool = False) -> str:
+        """
+        Default Instructions to the LLM showing how to use the tool/function-call.
+        Works for GPT4 but override this for weaker LLMs if needed.
+
+        Args:
+            tool: instructions for Langroid-native tool use? (e.g. for non-OpenAI LLM)
+                (or else it would be for OpenAI Function calls)
+        Returns:
+            str: instructions on how to use the message
+        """
+        # TODO: when we attempt to use a "simpler schema"
+        # (i.e. all nested fields explicit without definitions),
+        # we seem to get worse results, so we turn it off for now
+        param_dict = (
+            # cls.simple_schema() if tool else
+            cls.llm_function_schema(request=True).parameters
+        )
+        return textwrap.dedent(
+            f"""
+            TOOL: {cls.default_value("request")}
+            PURPOSE: {cls.default_value("purpose")}
+            JSON FORMAT: {
+                json.dumps(param_dict, indent=4)
+            }
+            {"EXAMPLE: " + cls.usage_example() if cls.examples() else ""}
+            """.lstrip()
+        )
+
+    @staticmethod
+    def json_group_instructions() -> str:
+        """Template for instructions for a group of tools.
+        Works with GPT4 but override this for weaker LLMs if needed.
+        """
+        return textwrap.dedent(
+            """
+            === ALL AVAILABLE TOOLS and THEIR JSON FORMAT INSTRUCTIONS ===
+            You have access to the following TOOLS to accomplish your task:
+
+            {json_instructions}
+
+            When one of the above TOOLs is applicable, you must express your
+            request as "TOOL:" followed by the request in the above JSON format.
+            """
+        )
+
     @classmethod
     def llm_function_schema(
         cls,
@@ -178,3 +224,14 @@ class ToolMessage(ABC, BaseModel):
             description=cls.default_value("purpose"),
             parameters=parameters,
         )
+
+    @classmethod
+    def simple_schema(cls) -> Dict[str, Any]:
+        """
+        Return a simplified schema for the message, with only the request and
+        required fields.
+        Returns:
+            Dict[str, Any]: simplified schema
+        """
+        schema = generate_simple_schema(cls, exclude=["result", "purpose"])
+        return schema
langroid/agent/tools/duckduckgo_search_tool.py
ADDED
"""
A tool to trigger a DuckDuckGo search for a given query,
and return the top results with their titles, links, summaries.
Since the tool is stateless (i.e. does not need
access to agent state), it can be enabled for any agent, without having to define a
special method inside the agent: `agent.enable_message(DuckduckgoSearchTool)`
"""

from typing import List

from langroid.agent.tool_message import ToolMessage
from langroid.parsing.web_search import duckduckgo_search


class DuckduckgoSearchTool(ToolMessage):
    request: str = "duckduckgo_search"
    purpose: str = """
            To search the web and return up to <num_results>
            links relevant to the given <query>. When using this tool,
            ONLY show the required JSON, DO NOT SAY ANYTHING ELSE.
            Wait for the results of the web search, and then use them to
            compose your response.
            """
    query: str
    num_results: int

    def handle(self) -> str:
        """
        Conducts a DuckDuckGo web search based on the provided query
        and number of results.

        Returns:
            str: A formatted string containing the titles, links, and
                summaries of each search result, separated by two newlines.
        """
        search_results = duckduckgo_search(self.query, self.num_results)
        # return Title, Link, Summary of each result, separated by two newlines
        results_str = "\n\n".join(str(result) for result in search_results)
        return f"""
        BELOW ARE THE RESULTS FROM THE WEB SEARCH. USE THESE TO COMPOSE YOUR RESPONSE:
        {results_str}
        """

    @classmethod
    def examples(cls) -> List["ToolMessage"]:
        return [
            cls(
                query="When was the Llama2 Large Language Model (LLM) released?",
                num_results=3,
            ),
        ]
langroid/agent/tools/google_search_tool.py
CHANGED
@@ -9,6 +9,8 @@ environment variables in your `.env` file, as explained in the
 [README](https://github.com/langroid/langroid#gear-installation-and-setup).
 """
 
+from typing import List
+
 from langroid.agent.tool_message import ToolMessage
 from langroid.parsing.web_search import google_search
 
@@ -26,3 +28,12 @@ class GoogleSearchTool(ToolMessage):
         search_results = google_search(self.query, self.num_results)
         # return Title, Link, Summary of each result, separated by two newlines
         return "\n\n".join(str(result) for result in search_results)
+
+    @classmethod
+    def examples(cls) -> List["ToolMessage"]:
+        return [
+            cls(
+                query="When was the Llama2 Large Language Model (LLM) released?",
+                num_results=3,
+            ),
+        ]
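
The new `examples()` classmethod feeds the few-shot example that `usage_example()` picks and `json_instructions()` embeds. A sketch (output formatting approximate):

from langroid.agent.tools.google_search_tool import GoogleSearchTool

ex = GoogleSearchTool.examples()[0]
print(ex.json_example())
# Roughly:
# {
#     "request": "google_search",
#     "query": "When was the Llama2 Large Language Model (LLM) released?",
#     "num_results": 3
# }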
langroid/agent/tools/metaphor_search_tool.py
ADDED
"""
A tool to trigger a Metaphor search for a given query
(https://docs.exa.ai/reference/getting-started)
and return the top results with their titles, links, summaries.
Since the tool is stateless (i.e. does not need
access to agent state), it can be enabled for any agent, without having to define a
special method inside the agent: `agent.enable_message(MetaphorSearchTool)`

NOTE: To use this tool, you need to:

* set the METAPHOR_API_KEY environment variable in
  your `.env` file, e.g. `METAPHOR_API_KEY=your_api_key_here`
  (Note: as of 28 Jan 2024, Metaphor was renamed to Exa, so you can also use
  `EXA_API_KEY=your_api_key_here`)

* install langroid with the `metaphor` extra, e.g.
  `pip install langroid[metaphor]` or `poetry add langroid[metaphor]`
  (it installs the `metaphor-python` package from pypi).

For more information, please refer to the official docs:
https://metaphor.systems/
"""

from typing import List

from langroid.agent.tool_message import ToolMessage
from langroid.parsing.web_search import metaphor_search


class MetaphorSearchTool(ToolMessage):
    request: str = "metaphor_search"
    purpose: str = """
            To search the web and return up to <num_results>
            links relevant to the given <query>. When using this tool,
            ONLY show the required JSON, DO NOT SAY ANYTHING ELSE.
            Wait for the results of the web search, and then use them to
            compose your response.
            """
    query: str
    num_results: int

    def handle(self) -> str:
        """
        Conducts a search using the Metaphor API based on the provided query
        and number of results, by triggering a metaphor_search.

        Returns:
            str: A formatted string containing the titles, links, and
                summaries of each search result, separated by two newlines.
        """

        search_results = metaphor_search(self.query, self.num_results)
        # return Title, Link, Summary of each result, separated by two newlines
        results_str = "\n\n".join(str(result) for result in search_results)
        return f"""
        BELOW ARE THE RESULTS FROM THE WEB SEARCH. USE THESE TO COMPOSE YOUR RESPONSE:
        {results_str}
        """

    @classmethod
    def examples(cls) -> List["ToolMessage"]:
        return [
            cls(
                query="When was the Llama2 Large Language Model (LLM) released?",
                num_results=3,
            ),
        ]
langroid/agent/tools/recipient_tool.py
CHANGED
@@ -6,25 +6,8 @@ the method `_get_tool_list()`).
 
 See usage examples in `tests/main/test_multi_agent_complex.py` and
 `tests/main/test_recipient_tool.py`.
-
-Previously we were using RecipientValidatorAgent to enforce proper
-recipient specification, but the preferred method is to use the
-`RecipientTool` class. This has numerous advantages:
-- it uses the tool/function-call mechanism to specify a recipient in a JSON-structured
-  string, which is more consistent with the rest of the system, and does not require
-  inventing a new syntax like `TO:<recipient>` (which the RecipientValidatorAgent
-  uses).
-- it removes the need for any special parsing of the message content, since we leverage
-  the built-in JSON tool-matching in `Agent.handle_message()` and downstream code.
-- it does not require setting the `parent_responder` field in the `ChatDocument`
-  metadata, which is somewhat hacky.
-- it appears to be less brittle than requiring the LLM to use TO:<recipient> syntax:
-  The LLM almost never forgets to use the RecipientTool as instructed.
-- The RecipientTool class acts as a specification of the required syntax, and also
-  contains mechanisms to enforce this syntax.
-- For a developer who needs to enforce recipient specification for an agent, they only
-  need to do `agent.enable_message(RecipientTool)`, and the rest is taken care of.
 """
+
 from typing import List, Type
 
 from rich import print
@@ -68,17 +51,16 @@ class AddRecipientTool(ToolMessage):
         )
         if self.__class__.saved_content == "":
             recipient_request_name = RecipientTool.default_value("request")
-
-            f"""
+            content = f"""
             Recipient specified but content is empty!
             This could be because the `{self.request}` tool/function was used
             before using `{recipient_request_name}` tool/function.
             Resend the message using `{recipient_request_name}` tool/function.
             """
-
-
-
-
+        else:
+            content = self.__class__.saved_content  # use class-level attrib value
+            # erase content since we just used it.
+            self.__class__.saved_content = ""
         return ChatDocument(
             content=content,
             metadata=ChatDocMetaData(
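
The removed docstring's advice still applies: a developer only needs to enable the tool. A minimal sketch (the agent name is illustrative):

from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.agent.tools.recipient_tool import RecipientTool

coordinator = ChatAgent(ChatAgentConfig(name="Coordinator"))
# Forces the LLM to address each message to an explicit recipient via a
# JSON tool call, instead of the old TO:<recipient> syntax.
coordinator.enable_message(RecipientTool)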
langroid/agent/tools/sciphi_search_rag_tool.py
ADDED
"""
A tool which returns a Search RAG response from the SciPhi API:
a response summary plus related queries. Since the tool is stateless
(i.e. does not need access to agent state), it can be enabled for any agent,
without having to define a special method inside the agent:
`agent.enable_message(SciPhiSearchRAGTool)`

Example return output appears as follows below:

<-- Query -->
```
Find 3 results on the internet about the LK-99 superconducting material.
```

<-- Response (compressed for this example) -->
```
[ result1 ]

[ result2 ]

[ result3 ]
```

NOTE: Using this tool requires getting an API key from sciphi.ai.
Setup is as simple as shown below:
# Get a free API key at https://www.sciphi.ai/account
# export SCIPHI_API_KEY=$MY_SCIPHI_API_KEY before running the agent
# OR add SCIPHI_API_KEY=$MY_SCIPHI_API_KEY to your .env file

This tool requires installing langroid with the `sciphi` extra, e.g.
`pip install langroid[sciphi]` or `poetry add langroid[sciphi]`
(it installs the `agent-search` package from pypi).

For more information, please refer to the official docs:
https://agent-search.readthedocs.io/en/latest/
"""

from typing import List

try:
    from agent_search import SciPhi
except ImportError:
    raise ImportError(
        "You are attempting to use the `agent-search` library; "
        "to use it, please install langroid with the `sciphi` extra, e.g. "
        "`pip install langroid[sciphi]` or `poetry add langroid[sciphi]` "
        "(it installs the `agent-search` package from pypi)."
    )

from langroid.agent.tool_message import ToolMessage


class SciPhiSearchRAGTool(ToolMessage):
    request: str = "web_search_rag"
    purpose: str = """
            To search the web with provider <search_provider> and
            return a response summary with llm model <llm_model> for the given <query>.
            """
    query: str

    def handle(self) -> str:
        rag_response = SciPhi().get_search_rag_response(
            query=self.query, search_provider="bing", llm_model="SciPhi/Sensei-7B-V1"
        )
        result = rag_response["response"]
        result = (
            f"### RAG Response:\n{result}\n\n"
            + "### Related Queries:\n"
            + "\n".join(rag_response["related_queries"])
        )
        return result  # type: ignore

    @classmethod
    def examples(cls) -> List["ToolMessage"]:
        return [
            cls(
                query="When was the Llama2 Large Language Model (LLM) released?",
            ),
        ]
langroid/embedding_models/__init__.py
CHANGED
@@ -1,6 +1,11 @@
 from . import base
 from . import models
+from . import remote_embeds
 
+from .base import (
+    EmbeddingModel,
+    EmbeddingModelsConfig,
+)
 from .models import (
     OpenAIEmbeddings,
     OpenAIEmbeddingsConfig,
@@ -8,3 +13,22 @@ from .models import (
     SentenceTransformerEmbeddings,
     embedding_model,
 )
+from .remote_embeds import (
+    RemoteEmbeddingsConfig,
+    RemoteEmbeddings,
+)
+
+__all__ = [
+    "base",
+    "models",
+    "remote_embeds",
+    "EmbeddingModel",
+    "EmbeddingModelsConfig",
+    "OpenAIEmbeddings",
+    "OpenAIEmbeddingsConfig",
+    "SentenceTransformerEmbeddingsConfig",
+    "SentenceTransformerEmbeddings",
+    "embedding_model",
+    "RemoteEmbeddingsConfig",
+    "RemoteEmbeddings",
+]
langroid/embedding_models/base.py
CHANGED
@@ -12,6 +12,8 @@ logging.getLogger("openai").setLevel(logging.ERROR)
 class EmbeddingModelsConfig(BaseSettings):
     model_type: str = "openai"
     dims: int = 0
+    context_length: int = 512
+    batch_size: int = 512
 
 
 class EmbeddingModel(ABC):
@@ -27,8 +29,14 @@ class EmbeddingModel(ABC):
             SentenceTransformerEmbeddings,
             SentenceTransformerEmbeddingsConfig,
         )
+        from langroid.embedding_models.remote_embeds import (
+            RemoteEmbeddings,
+            RemoteEmbeddingsConfig,
+        )
 
-        if isinstance(config, OpenAIEmbeddingsConfig):
+        if isinstance(config, RemoteEmbeddingsConfig):
+            return RemoteEmbeddings(config)
+        elif isinstance(config, OpenAIEmbeddingsConfig):
             return OpenAIEmbeddings(config)
         elif isinstance(config, SentenceTransformerEmbeddingsConfig):
             return SentenceTransformerEmbeddings(config)
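
The hunk context suggests this dispatch lives in the `EmbeddingModel.create` factory; a sketch of the new precedence (RemoteEmbeddingsConfig is checked first, so remote configs no longer fall through to the other branches):

from langroid.embedding_models.base import EmbeddingModel
from langroid.embedding_models.remote_embeds import RemoteEmbeddingsConfig

config = RemoteEmbeddingsConfig()  # default field values from remote_embeds.py
model = EmbeddingModel.create(config)  # now returns a RemoteEmbeddings instance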
langroid/embedding_models/models.py
CHANGED
@@ -1,29 +1,93 @@
+import atexit
 import os
-from typing import Callable, List
+from typing import Callable, List, Optional
 
+import tiktoken
 from dotenv import load_dotenv
 from openai import OpenAI
 
 from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
-from langroid.language_models.utils import retry_with_exponential_backoff
 from langroid.mytypes import Embeddings
+from langroid.parsing.utils import batched
 
 
 class OpenAIEmbeddingsConfig(EmbeddingModelsConfig):
     model_type: str = "openai"
     model_name: str = "text-embedding-ada-002"
     api_key: str = ""
+    api_base: Optional[str] = None
     organization: str = ""
     dims: int = 1536
+    context_length: int = 8192
 
 
 class SentenceTransformerEmbeddingsConfig(EmbeddingModelsConfig):
     model_type: str = "sentence-transformer"
     model_name: str = "BAAI/bge-large-en-v1.5"
+    context_length: int = 512
+    data_parallel: bool = False
+    # Select device (e.g. "cuda", "cpu") when data parallel is disabled
+    device: Optional[str] = None
+    # Select devices when data parallel is enabled
+    devices: Optional[list[str]] = None
+
+
+class EmbeddingFunctionCallable:
+    """
+    A callable class designed to generate embeddings for a list of texts using
+    the OpenAI API, with automatic retries on failure.
+
+    Attributes:
+        model (OpenAIEmbeddings): An instance of OpenAIEmbeddings that provides
+            configuration and utilities for generating embeddings.
+
+    Methods:
+        __call__(input: List[str]) -> Embeddings: Generate embeddings for
+            a list of input texts.
+    """
+
+    def __init__(self, model: "OpenAIEmbeddings", batch_size: int = 512):
+        """
+        Initialize the EmbeddingFunctionCallable with a specific model.
+
+        Args:
+            model (OpenAIEmbeddings): An instance of OpenAIEmbeddings to use for
+                generating embeddings.
+            batch_size (int): Batch size
+        """
+        self.model = model
+        self.batch_size = batch_size
+
+    def __call__(self, input: List[str]) -> Embeddings:
+        """
+        Generate embeddings for a given list of input texts using the OpenAI API,
+        with retries on failure.
+
+        This method:
+        - Truncates each text in the input list to the model's maximum context length.
+        - Processes the texts in batches to generate embeddings efficiently.
+        - Automatically retries the embedding generation process with exponential
+          backoff in case of failures.
+
+        Args:
+            input (List[str]): A list of input texts to generate embeddings for.
+
+        Returns:
+            Embeddings: A list of embedding vectors corresponding to the input texts.
+        """
+        tokenized_texts = self.model.truncate_texts(input)
+        embeds = []
+        for batch in batched(tokenized_texts, self.batch_size):
+            result = self.model.client.embeddings.create(
+                input=batch, model=self.model.config.model_name
+            )
+            batch_embeds = [d.embedding for d in result.data]
+            embeds.extend(batch_embeds)
+        return embeds
 
 
 class OpenAIEmbeddings(EmbeddingModel):
-    def __init__(self, config: OpenAIEmbeddingsConfig):
+    def __init__(self, config: OpenAIEmbeddingsConfig = OpenAIEmbeddingsConfig()):
         super().__init__()
         self.config = config
         load_dotenv()
@@ -36,28 +100,38 @@ class OpenAIEmbeddings(EmbeddingModel):
                 in your .env file.
                 """
             )
-        self.client = OpenAI(api_key=self.config.api_key)
+        self.client = OpenAI(base_url=self.config.api_base, api_key=self.config.api_key)
+        self.tokenizer = tiktoken.encoding_for_model(self.config.model_name)
+
+    def truncate_texts(self, texts: List[str]) -> List[List[int]]:
+        """
+        Truncate texts to the embedding model's context length.
+        TODO: Maybe we should show warning, and consider doing T5 summarization?
+        """
+        return [
+            self.tokenizer.encode(text, disallowed_special=())[
+                : self.config.context_length
+            ]
+            for text in texts
+        ]
 
     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
-
-        def fn(texts: List[str]) -> Embeddings:
-            result = self.client.embeddings.create(
-                input=texts, model=self.config.model_name
-            )
-            return [d.embedding for d in result.data]
-
-        return fn
+        return EmbeddingFunctionCallable(self, self.config.batch_size)
 
     @property
     def embedding_dims(self) -> int:
        return self.config.dims
 
 
+STEC = SentenceTransformerEmbeddingsConfig
+
+
 class SentenceTransformerEmbeddings(EmbeddingModel):
-    def __init__(self, config: SentenceTransformerEmbeddingsConfig):
+    def __init__(self, config: STEC = STEC()):
         # this is an "extra" optional dependency, so we import it here
         try:
             from sentence_transformers import SentenceTransformer
+            from transformers import AutoTokenizer
         except ImportError:
             raise ImportError(
                 """
@@ -69,13 +143,39 @@ class SentenceTransformerEmbeddings(EmbeddingModel):
 
         super().__init__()
         self.config = config
-
+
+        self.model = SentenceTransformer(
+            self.config.model_name,
+            device=self.config.device,
+        )
+        if self.config.data_parallel:
+            self.pool = self.model.start_multi_process_pool(
+                self.config.devices  # type: ignore
+            )
+            atexit.register(
+                lambda: SentenceTransformer.stop_multi_process_pool(self.pool)
+            )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self.config.context_length = self.tokenizer.model_max_length
 
     def embedding_fn(self) -> Callable[[List[str]], Embeddings]:
         def fn(texts: List[str]) -> Embeddings:
-
-
-
+            if self.config.data_parallel:
+                embeds: Embeddings = self.model.encode_multi_process(
+                    texts,
+                    self.pool,
+                    batch_size=self.config.batch_size,
+                ).tolist()
+            else:
+                embeds = []
+                for batch in batched(texts, self.config.batch_size):
+                    batch_embeds = self.model.encode(
+                        batch, convert_to_numpy=True
+                    ).tolist()  # type: ignore
+                    embeds.extend(batch_embeds)
+
+            return embeds
 
         return fn
 
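
A sketch of the new batched, truncated embedding path (assumes OPENAI_API_KEY is available via the environment or `.env`):

from langroid.embedding_models.models import OpenAIEmbeddings, OpenAIEmbeddingsConfig

embed_cfg = OpenAIEmbeddingsConfig(batch_size=16)
embed_fn = OpenAIEmbeddings(embed_cfg).embedding_fn()  # an EmbeddingFunctionCallable
# Each text is token-truncated to context_length (8192) before the batched
# API calls; the result is one 1536-dim vector per input text.
vectors = embed_fn(["hello world", "a much longer document ..."])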
langroid/embedding_models/protoc/embeddings.proto
ADDED
syntax = "proto3";

service Embedding {
  rpc Embed (EmbeddingRequest) returns (BatchEmbeds) {};
}

message EmbeddingRequest {
  string model_name = 1;
  int32 batch_size = 2;
  repeated string strings = 3;
}

message BatchEmbeds {
  repeated Embed embeds = 1;
}

message Embed {
  repeated float embed = 1;
}
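
The generated `embeddings_pb2*` files in the summary above correspond to this proto; a sketch of regenerating them with `grpcio-tools` (version assumed recent enough to support `--pyi_out`):

from grpc_tools import protoc

protoc.main(
    [
        "protoc",
        "-Ilangroid/embedding_models/protoc",
        "--python_out=langroid/embedding_models/protoc",
        "--pyi_out=langroid/embedding_models/protoc",
        "--grpc_python_out=langroid/embedding_models/protoc",
        "langroid/embedding_models/protoc/embeddings.proto",
    ]
)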
|