langroid 0.31.2__py3-none-any.whl → 0.33.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info}/METADATA +150 -124
- langroid-0.33.3.dist-info/RECORD +7 -0
- {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info}/WHEEL +1 -1
- langroid-0.33.3.dist-info/entry_points.txt +4 -0
- pyproject.toml +317 -212
- langroid/__init__.py +0 -106
- langroid/agent/.chainlit/config.toml +0 -121
- langroid/agent/.chainlit/translations/bn.json +0 -231
- langroid/agent/.chainlit/translations/en-US.json +0 -229
- langroid/agent/.chainlit/translations/gu.json +0 -231
- langroid/agent/.chainlit/translations/he-IL.json +0 -231
- langroid/agent/.chainlit/translations/hi.json +0 -231
- langroid/agent/.chainlit/translations/kn.json +0 -231
- langroid/agent/.chainlit/translations/ml.json +0 -231
- langroid/agent/.chainlit/translations/mr.json +0 -231
- langroid/agent/.chainlit/translations/ta.json +0 -231
- langroid/agent/.chainlit/translations/te.json +0 -231
- langroid/agent/.chainlit/translations/zh-CN.json +0 -229
- langroid/agent/__init__.py +0 -41
- langroid/agent/base.py +0 -1981
- langroid/agent/batch.py +0 -398
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +0 -598
- langroid/agent/chat_agent.py +0 -1899
- langroid/agent/chat_document.py +0 -454
- langroid/agent/helpers.py +0 -0
- langroid/agent/junk +0 -13
- langroid/agent/openai_assistant.py +0 -882
- langroid/agent/special/__init__.py +0 -59
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +0 -656
- langroid/agent/special/arangodb/system_messages.py +0 -186
- langroid/agent/special/arangodb/tools.py +0 -107
- langroid/agent/special/arangodb/utils.py +0 -36
- langroid/agent/special/doc_chat_agent.py +0 -1466
- langroid/agent/special/lance_doc_chat_agent.py +0 -262
- langroid/agent/special/lance_rag/__init__.py +0 -9
- langroid/agent/special/lance_rag/critic_agent.py +0 -198
- langroid/agent/special/lance_rag/lance_rag_task.py +0 -82
- langroid/agent/special/lance_rag/query_planner_agent.py +0 -260
- langroid/agent/special/lance_tools.py +0 -61
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +0 -174
- langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -433
- langroid/agent/special/neo4j/system_messages.py +0 -120
- langroid/agent/special/neo4j/tools.py +0 -32
- langroid/agent/special/relevance_extractor_agent.py +0 -127
- langroid/agent/special/retriever_agent.py +0 -56
- langroid/agent/special/sql/__init__.py +0 -17
- langroid/agent/special/sql/sql_chat_agent.py +0 -654
- langroid/agent/special/sql/utils/__init__.py +0 -21
- langroid/agent/special/sql/utils/description_extractors.py +0 -190
- langroid/agent/special/sql/utils/populate_metadata.py +0 -85
- langroid/agent/special/sql/utils/system_message.py +0 -35
- langroid/agent/special/sql/utils/tools.py +0 -64
- langroid/agent/special/table_chat_agent.py +0 -263
- langroid/agent/structured_message.py +0 -9
- langroid/agent/task.py +0 -2093
- langroid/agent/tool_message.py +0 -393
- langroid/agent/tools/__init__.py +0 -38
- langroid/agent/tools/duckduckgo_search_tool.py +0 -50
- langroid/agent/tools/file_tools.py +0 -234
- langroid/agent/tools/google_search_tool.py +0 -39
- langroid/agent/tools/metaphor_search_tool.py +0 -67
- langroid/agent/tools/orchestration.py +0 -303
- langroid/agent/tools/recipient_tool.py +0 -235
- langroid/agent/tools/retrieval_tool.py +0 -32
- langroid/agent/tools/rewind_tool.py +0 -137
- langroid/agent/tools/segment_extract_tool.py +0 -41
- langroid/agent/typed_task.py +0 -19
- langroid/agent/xml_tool_message.py +0 -382
- langroid/agent_config.py +0 -0
- langroid/cachedb/__init__.py +0 -17
- langroid/cachedb/base.py +0 -58
- langroid/cachedb/momento_cachedb.py +0 -108
- langroid/cachedb/redis_cachedb.py +0 -153
- langroid/embedding_models/__init__.py +0 -39
- langroid/embedding_models/base.py +0 -74
- langroid/embedding_models/clustering.py +0 -189
- langroid/embedding_models/models.py +0 -461
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +0 -19
- langroid/embedding_models/protoc/embeddings_pb2.py +0 -33
- langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -50
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -79
- langroid/embedding_models/remote_embeds.py +0 -153
- langroid/exceptions.py +0 -65
- langroid/experimental/team-save.py +0 -391
- langroid/language_models/.chainlit/config.toml +0 -121
- langroid/language_models/.chainlit/translations/en-US.json +0 -231
- langroid/language_models/__init__.py +0 -53
- langroid/language_models/azure_openai.py +0 -153
- langroid/language_models/base.py +0 -678
- langroid/language_models/config.py +0 -18
- langroid/language_models/mock_lm.py +0 -124
- langroid/language_models/openai_gpt.py +0 -1923
- langroid/language_models/prompt_formatter/__init__.py +0 -16
- langroid/language_models/prompt_formatter/base.py +0 -40
- langroid/language_models/prompt_formatter/hf_formatter.py +0 -132
- langroid/language_models/prompt_formatter/llama2_formatter.py +0 -75
- langroid/language_models/utils.py +0 -147
- langroid/mytypes.py +0 -84
- langroid/parsing/__init__.py +0 -52
- langroid/parsing/agent_chats.py +0 -38
- langroid/parsing/code-parsing.md +0 -86
- langroid/parsing/code_parser.py +0 -121
- langroid/parsing/config.py +0 -0
- langroid/parsing/document_parser.py +0 -718
- langroid/parsing/image_text.py +0 -32
- langroid/parsing/para_sentence_split.py +0 -62
- langroid/parsing/parse_json.py +0 -155
- langroid/parsing/parser.py +0 -313
- langroid/parsing/repo_loader.py +0 -790
- langroid/parsing/routing.py +0 -36
- langroid/parsing/search.py +0 -275
- langroid/parsing/spider.py +0 -102
- langroid/parsing/table_loader.py +0 -94
- langroid/parsing/url_loader.py +0 -111
- langroid/parsing/url_loader_cookies.py +0 -73
- langroid/parsing/urls.py +0 -273
- langroid/parsing/utils.py +0 -373
- langroid/parsing/web_search.py +0 -155
- langroid/prompts/__init__.py +0 -9
- langroid/prompts/chat-gpt4-system-prompt.md +0 -68
- langroid/prompts/dialog.py +0 -17
- langroid/prompts/prompts_config.py +0 -5
- langroid/prompts/templates.py +0 -141
- langroid/pydantic_v1/__init__.py +0 -10
- langroid/pydantic_v1/main.py +0 -4
- langroid/utils/.chainlit/config.toml +0 -121
- langroid/utils/.chainlit/translations/en-US.json +0 -231
- langroid/utils/__init__.py +0 -19
- langroid/utils/algorithms/__init__.py +0 -3
- langroid/utils/algorithms/graph.py +0 -103
- langroid/utils/configuration.py +0 -98
- langroid/utils/constants.py +0 -30
- langroid/utils/docker.py +0 -37
- langroid/utils/git_utils.py +0 -252
- langroid/utils/globals.py +0 -49
- langroid/utils/llms/__init__.py +0 -0
- langroid/utils/llms/strings.py +0 -8
- langroid/utils/logging.py +0 -135
- langroid/utils/object_registry.py +0 -66
- langroid/utils/output/__init__.py +0 -20
- langroid/utils/output/citations.py +0 -41
- langroid/utils/output/printing.py +0 -99
- langroid/utils/output/status.py +0 -40
- langroid/utils/pandas_utils.py +0 -30
- langroid/utils/pydantic_utils.py +0 -602
- langroid/utils/system.py +0 -286
- langroid/utils/types.py +0 -93
- langroid/utils/web/__init__.py +0 -0
- langroid/utils/web/login.py +0 -83
- langroid/vector_store/__init__.py +0 -50
- langroid/vector_store/base.py +0 -357
- langroid/vector_store/chromadb.py +0 -214
- langroid/vector_store/lancedb.py +0 -401
- langroid/vector_store/meilisearch.py +0 -299
- langroid/vector_store/momento.py +0 -278
- langroid/vector_store/qdrant_cloud.py +0 -6
- langroid/vector_store/qdrantdb.py +0 -468
- langroid-0.31.2.dist-info/RECORD +0 -162
- {langroid-0.31.2.dist-info → langroid-0.33.3.dist-info/licenses}/LICENSE +0 -0
@@ -1,153 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
If run as a script, starts an RPC server which handles remote
|
3
|
-
embedding requests:
|
4
|
-
|
5
|
-
For example:
|
6
|
-
python3 -m langroid.embedding_models.remote_embeds --port `port`
|
7
|
-
|
8
|
-
where `port` is the port at which the service is exposed. Currently, only
insecure connections are supported, so this service should NOT be exposed to
the internet.
|
11
|
-
"""
|
12
|
-
|
13
|
-
import atexit
|
14
|
-
import subprocess
|
15
|
-
import time
|
16
|
-
from typing import Callable, Optional
|
17
|
-
|
18
|
-
import grpc
|
19
|
-
from fire import Fire
|
20
|
-
|
21
|
-
import langroid.embedding_models.models as em
|
22
|
-
import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb
|
23
|
-
import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_grpc
|
24
|
-
from langroid.mytypes import Embeddings
|
25
|
-
|
26
|
-
|
27
|
-
class RemoteEmbeddingRPCs(embeddings_grpc.EmbeddingServicer):
    """gRPC servicer that computes embeddings with a SentenceTransformer model.

    The embedding callable is built once, at construction time, and reused for
    every `Embed` request.
    """

    def __init__(
        self,
        model_name: str,
        batch_size: int,
        data_parallel: bool,
        device: Optional[str],
        devices: Optional[list[str]],
    ):
        super().__init__()
        st_config = em.SentenceTransformerEmbeddingsConfig(
            model_name=model_name,
            batch_size=batch_size,
            data_parallel=data_parallel,
            device=device,
            devices=devices,
        )
        st_model = em.SentenceTransformerEmbeddings(st_config)
        self.embedding_fn = st_model.embedding_fn()

    def Embed(
        self, request: embeddings_pb.EmbeddingRequest, _: grpc.RpcContext
    ) -> embeddings_pb.BatchEmbeds:
        """Embed every string in `request` and return them as a BatchEmbeds."""
        vectors = self.embedding_fn(list(request.strings))
        return embeddings_pb.BatchEmbeds(
            embeds=[embeddings_pb.Embed(embed=vector) for vector in vectors]
        )
|
55
|
-
|
56
|
-
|
57
|
-
class RemoteEmbeddingsConfig(em.SentenceTransformerEmbeddingsConfig):
    """Settings for `RemoteEmbeddings`: where the embedding gRPC server lives.

    Model settings (model_name, batch_size, ...) are inherited from
    `SentenceTransformerEmbeddingsConfig`.
    """

    # Host of the server. When "localhost", RemoteEmbeddings launches a local
    # server on first use if no server answers.
    api_base: str = "localhost"
    # Port the server listens on (same default as `serve`).
    port: int = 50052
    # The below are used only when waiting for server creation
    # (after RemoteEmbeddings has launched a local server):
    # seconds to sleep between connection attempts,
    poll_delay: float = 0.01
    # and the number of connection attempts made before giving up.
    max_retries: int = 1000
|
63
|
-
|
64
|
-
|
65
|
-
class RemoteEmbeddings(em.SentenceTransformerEmbeddings):
    """Embedding model that calls a gRPC embedding server (see `serve`).

    If `config.api_base` is "localhost" and the first request fails, a local
    server is launched from this module and polled until it answers.
    """

    def __init__(self, config: Optional[RemoteEmbeddingsConfig] = None):
        # Create the default per instance. A `RemoteEmbeddingsConfig()` default
        # in the signature would be a single mutable object shared by every
        # RemoteEmbeddings created without an explicit config.
        if config is None:
            config = RemoteEmbeddingsConfig()
        super().__init__(config)
        self.config: RemoteEmbeddingsConfig = config
        # True once we have tried to launch a local server; ensures at most
        # one launch attempt per instance.
        self.have_started_server: bool = False

    def _server_command(self) -> list[str]:
        """Build the argv that runs this module as a local embedding server.

        All model settings are forwarded, so the server embeds with the same
        configuration as this client. Device settings were previously dropped,
        which made the server silently fall back to `serve`'s defaults.
        """
        import sys

        cmd = [
            # Use the current interpreter so the child process sees the same
            # environment and packages; a bare "python3" may resolve elsewhere.
            sys.executable,
            __file__,
            "--bind_address_base",
            self.config.api_base,
            "--port",
            str(self.config.port),
            "--batch_size",
            str(self.config.batch_size),
            "--model_name",
            self.config.model_name,
            # `=` form so Fire reads the value as a literal (bool / list).
            f"--data_parallel={self.config.data_parallel}",
        ]
        if self.config.device is not None:
            cmd.append(f"--device={self.config.device}")
        if self.config.devices is not None:
            cmd.append(f"--devices={list(self.config.devices)!r}")
        return cmd

    def embedding_fn(self) -> Callable[[list[str]], Embeddings]:
        """Return a callable mapping a list of texts to their embeddings.

        Raises:
            grpc.RpcError: if the server cannot be reached, i.e. it is not
                local, or a local server did not come up within
                `max_retries` attempts.
        """

        def fn(texts: list[str]) -> Embeddings:
            # One RPC round-trip on a fresh insecure channel.
            url = f"{self.config.api_base}:{self.config.port}"
            with grpc.insecure_channel(url) as channel:
                stub = embeddings_grpc.EmbeddingStub(channel)  # type: ignore
                response = stub.Embed(
                    embeddings_pb.EmbeddingRequest(
                        strings=texts,
                    )
                )

            return [list(emb.embed) for emb in response.embeds]

        def with_handling(texts: list[str]) -> Embeddings:
            # In local mode, start the server if it has not already
            # been started
            if self.config.api_base == "localhost" and not self.have_started_server:
                try:
                    return fn(texts)
                # Occurs when the server hasn't been started
                except grpc.RpcError:
                    self.have_started_server = True
                    proc = subprocess.Popen(self._server_command())
                    # Do not leave an orphaned server behind on exit.
                    atexit.register(proc.terminate)

                    # Poll until the freshly started server answers.
                    for _ in range(self.config.max_retries - 1):
                        try:
                            return fn(texts)
                        except grpc.RpcError:
                            time.sleep(self.config.poll_delay)

            # The remote is not local or we have exhausted retries
            # We should now raise an error if the server is not accessible
            return fn(texts)

        return with_handling
|
122
|
-
|
123
|
-
|
124
|
-
async def serve(
    bind_address_base: str = "localhost",
    port: int = 50052,
    batch_size: int = 512,
    data_parallel: bool = False,
    device: Optional[str] = None,
    devices: Optional[list[str]] = None,
    model_name: str = "BAAI/bge-large-en-v1.5",
) -> None:
    """Start the embedding RPC server and wait until it terminates.

    The servicer wraps a SentenceTransformer model built from the given
    settings. The server listens without TLS on `bind_address_base:port`.
    """
    server = grpc.aio.server()
    servicer = RemoteEmbeddingRPCs(
        model_name=model_name,
        batch_size=batch_size,
        data_parallel=data_parallel,
        device=device,
        devices=devices,
    )
    embeddings_grpc.add_EmbeddingServicer_to_server(servicer, server)  # type: ignore
    address = f"{bind_address_base}:{port}"
    server.add_insecure_port(address)
    await server.start()
    print(f"Embedding server started, listening on {address}")
    await server.wait_for_termination()
|
150
|
-
|
151
|
-
|
152
|
-
if __name__ == "__main__":
    # Expose `serve` as a command-line interface via python-fire; its keyword
    # parameters become flags such as --port and --model_name.
    Fire(serve)
|
langroid/exceptions.py
DELETED
@@ -1,65 +0,0 @@
|
|
1
|
-
from typing import Optional
|
2
|
-
|
3
|
-
|
4
|
-
class XMLException(Exception):
    """Exception carrying a single message string (intended for XML errors)."""

    def __init__(self, message: str) -> None:
        # Forward the message unchanged so `str(exc)` and `exc.args` expose it.
        super().__init__(message)
|
7
|
-
|
8
|
-
|
9
|
-
class InfiniteLoopException(Exception):
    """Exception whose message defaults to "Infinite loop detected"."""

    def __init__(self, message: str = "Infinite loop detected", *args: object) -> None:
        all_args = (message, *args)
        super().__init__(*all_args)
|
12
|
-
|
13
|
-
|
14
|
-
class LangroidImportError(ImportError):
    def __init__(
        self,
        package: Optional[str] = None,
        extra: Optional[str] = None,
        error: str = "",
        *args: object,
    ) -> None:
        """
        Generate helpful warning when attempting to import package or module.

        Args:
            package (str): The name of the package to import.
            extra (str): The name of the extras package required for this import.
            error (str): The error message to display. Depending on context, we
                can set this by capturing the ImportError message.

        The install instructions are dedented before being appended, so the
        message shown to users does not carry the source-code indentation of
        the string literal.
        """
        from textwrap import dedent

        # Kept for programmatic inspection by exception handlers.
        self.package = package
        self.extra = extra

        if error == "" and package is not None:
            error = f"{package} is not installed by default with Langroid.\n"

        if extra:
            install_help = f"""
                If you want to use it, please install langroid
                with the `{extra}` extra, for example:

                If you are using pip:
                pip install "langroid[{extra}]"

                For multiple extras, you can separate them with commas:
                pip install "langroid[{extra},another-extra]"

                If you are using Poetry:
                poetry add langroid --extras "{extra}"

                For multiple extras with Poetry, list them with spaces:
                poetry add langroid --extras "{extra} another-extra"

                If you are working within the langroid dev env (which uses Poetry),
                you can do:
                poetry install -E "{extra}"
                or if you want to include multiple extras:
                poetry install -E "{extra} another-extra"
                """
        else:
            install_help = """
                If you want to use it, please install it in the same
                virtual environment as langroid.
                """
        msg = error + dedent(install_help)

        super().__init__(msg, *args)
|
@@ -1,391 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
from abc import ABC, abstractmethod
|
3
|
-
from typing import Callable, Dict, List, Optional, Union
|
4
|
-
|
5
|
-
import langroid as lr
|
6
|
-
from langroid.language_models.mock_lm import MockLMConfig
|
7
|
-
|
8
|
-
# Configure root logging at WARNING: this demo reports its progress through
# logger.warning(...) calls, so that level must be visible.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
|
11
|
-
|
12
|
-
|
13
|
-
def sum_fn(s: str) -> str:
    """Dummy response for MockLM: the sum of the integers in `s`, plus one.

    Integers are runs of decimal digits separated by whitespace or commas,
    e.g. "1, 2 3" -> "7". Tokens such as "-2" or "3." are ignored.

    `str.isdecimal` is used rather than `str.isdigit`. The latter also accepts
    characters like superscript "²" that `int()` cannot parse, which made this
    function raise ValueError on such input.
    """
    nums = [
        int(token)
        for word in s.split()
        for token in word.split(",")
        if token.isdecimal()
    ]
    return str(sum(nums) + 1)
|
22
|
-
|
23
|
-
|
24
|
-
def user_message(msg: Union[str, lr.ChatDocument]) -> lr.ChatDocument:
    """Wrap a string as a USER-sender ChatDocument named "user".

    A ChatDocument argument is returned unchanged.
    """
    if not isinstance(msg, lr.ChatDocument):
        user_meta = lr.ChatDocMetaData(sender=lr.Entity.USER, sender_name="user")
        return lr.ChatDocument(content=msg, metadata=user_meta)
    return msg
|
35
|
-
|
36
|
-
|
37
|
-
class InputContext:
    """Context for a Component to respond to: the messages pushed to it."""

    def __init__(self) -> None:
        self.messages: List[lr.ChatDocument] = []

    def add(
        self, results: Union[str, List[str], lr.ChatDocument, List[lr.ChatDocument]]
    ) -> None:
        """
        Add messages to the input messages list.

        Strings are wrapped as user-role messages via `user_message`, and
        ChatDocuments are added unchanged. Lists are converted element by
        element, so a list that mixes strings and ChatDocuments keeps all of
        them. Previously only the first element's type was checked, and
        strings were silently dropped when the list began with a ChatDocument.
        Elements of any other type are ignored.
        """
        items = results if isinstance(results, list) else [results]
        self.messages.extend(
            user_message(item) if isinstance(item, str) else item
            for item in items
            if isinstance(item, (str, lr.ChatDocument))
        )

    def clear(self) -> None:
        self.messages.clear()

    def get_context(self) -> lr.ChatDocument:
        """Construct a user-role ChatDocument from the input messages.

        Each message contributes one "<sender_name>: <content>" line. With no
        messages, the content is the empty string.
        """
        # "\n".join of an empty sequence is "", covering the no-message case.
        content = "\n".join(
            f"{msg.metadata.sender_name}: {msg.content}" for msg in self.messages
        )
        return lr.ChatDocument(content=content, metadata={"sender": lr.Entity.USER})
|
74
|
-
|
75
|
-
|
76
|
-
class Scheduler(ABC):
    """Schedule the Components of a Team.

    Subclasses implement `step`, `done` and `result`. `run` resets the per-run
    state, steps until `done()` is true, and returns `result()`.
    """

    def __init__(self) -> None:
        self.init_state()

    def init_state(self) -> None:
        """Reset per-run bookkeeping."""
        self.stepped = False
        # names of components that gave a valid response, in response order
        self.responders: List[str] = []
        # number of valid responses per component name
        self.responder_counts: Dict[str, int] = {}
        self.current_result: List[lr.ChatDocument] = []

    @abstractmethod
    def step(self) -> None:
        """Advance the schedule by one step."""

    @abstractmethod
    def done(self) -> bool:
        """Return True when scheduling should stop."""

    @abstractmethod
    def result(self) -> List[lr.ChatDocument]:
        """Return the outcome of the run."""

    def run(self) -> List[lr.ChatDocument]:
        """Reset state, step until done, and return the result."""
        self.init_state()
        finished = self.done()
        while not finished:
            self.step()
            finished = self.done()
        return self.result()
|
105
|
-
|
106
|
-
|
107
|
-
class Component(ABC):
    """A component of a Team.

    Other components push results into this component's `input` context, and
    this component notifies its own listeners with its results.
    """

    def __init__(self) -> None:
        self.input = InputContext()
        self._listeners: List["Component"] = []
        self.name: str = ""

    @abstractmethod
    def run(self) -> List[lr.ChatDocument]:
        pass

    def listen(self, component: Union["Component", List["Component"]]) -> None:
        """Register self as a listener of `component` (or of each one in a list).

        Registration is idempotent. Listening to the same component twice no
        longer adds a duplicate entry, which previously delivered every result
        from that component twice.
        """
        sources = component if isinstance(component, list) else [component]
        for source in sources:
            if self not in source.listeners:
                source.listeners.append(self)

    @property
    def listeners(self) -> List["Component"]:
        return self._listeners

    def _notify(self, results: List[lr.ChatDocument]) -> None:
        """Add `results` to the input context of every listener."""
        logger.warning(f"{self.name} Notifying listeners...")
        for listener in self.listeners:
            logger.warning(f"--> Listener {listener.name} notified")
            listener.input.add(results)
|
135
|
-
|
136
|
-
|
137
|
-
class SimpleScheduler(Scheduler):
    """Run every component once, in order, and collect all of their results."""

    def __init__(
        self,
        components: List[Component],
    ) -> None:
        super().__init__()
        self.components = components  # Get components from team
        self.stepped: bool = False

    def step(self) -> None:
        collected: List[lr.ChatDocument] = []
        for component in self.components:
            output = component.run()
            if output:
                collected += output
        self.current_result = collected
        self.stepped = True

    def done(self) -> bool:
        """done after 1 step, i.e. all components have responded"""
        return self.stepped

    def result(self) -> List[lr.ChatDocument]:
        """The concatenated results gathered by the last step."""
        return self.current_result
|
161
|
-
|
162
|
-
|
163
|
-
class OrElseScheduler(Scheduler):
    """
    Implements "OrElse scheduling", i.e. if the components are A, B, C, then
    in each step, it will try for a valid response from A OrElse B OrElse C,
    i.e. the first component that gives a valid response is chosen.
    In the next step, it will start from the next component in the list,
    cycling back to the first component after the last component.
    (There may be a better name than OrElseScheduler though.)

    If a full cycle produces no valid response (including the case of no
    components), the scheduler is marked "stalled" and `done()` returns True.
    Without this check, the state never changes and `run()` would step
    forever, for example when the team's done-condition waits for a component
    that never responds.
    """

    def __init__(
        self,
        components: List[Component],
    ) -> None:
        super().__init__()
        self.components = components
        self.team: Optional[Team] = None
        self.current_index: int = 0

    def init_state(self) -> None:
        super().init_state()
        self.current_index = 0
        # Set by step() when a whole cycle gave no valid response.
        self.stalled = False

    def is_valid(self, result: Optional[List[lr.ChatDocument]]) -> bool:
        return result is not None and len(result) > 0

    def step(self) -> None:
        start_index = self.current_index
        n = len(self.components)

        for i in range(n):
            idx = (start_index + i) % n
            comp = self.components[idx]
            result = comp.run()
            if self.is_valid(result):
                self.responders.append(comp.name)
                self.responder_counts[comp.name] = (
                    self.responder_counts.get(comp.name, 0) + 1
                )
                self.current_result = result
                # cycle to next component
                self.current_index = (idx + 1) % n
                return
        # No component responded in a full cycle: stepping again cannot make
        # progress, so let done() end the run.
        self.stalled = True

    def done(self) -> bool:
        if self.stalled:
            return True
        if self.team is None:
            return False
        return self.team.done(self)

    def result(self) -> List[lr.ChatDocument]:
        return self.current_result
|
214
|
-
|
215
|
-
|
216
|
-
class Team(Component):
    """A Component made of other Components and run by a Scheduler.

    `done_condition(team, scheduler)` is consulted by schedulers that defer to
    their team (e.g. OrElseScheduler). The default condition always returns
    False.
    """

    def __init__(
        self,
        name: str,
        done_condition: Optional[Callable[["Team", Scheduler], bool]] = None,
    ) -> None:
        super().__init__()
        self.name = name
        self.components: List[Component] = []
        self.scheduler: Optional[Scheduler] = None
        self.done_condition = done_condition or Team.default_done_condition

    def set_done_condition(
        self, done_condition: Callable[["Team", Scheduler], bool]
    ) -> None:
        self.done_condition = done_condition

    def done(self, scheduler: Scheduler) -> bool:
        return self.done_condition(self, scheduler)

    def default_done_condition(self, scheduler: Scheduler) -> bool:
        # Default condition, can be overridden
        return False

    def add_scheduler(self, scheduler_class: type) -> None:
        # The scheduler receives this team's own `components` list object, so
        # components added later via `add` are still visible to it.
        self.scheduler = scheduler_class(self.components)
        if hasattr(self.scheduler, "team"):
            setattr(self.scheduler, "team", self)

    def add(self, component: Union[Component, List[Component]]) -> None:
        if isinstance(component, list):
            self.components.extend(component)
        else:
            self.components.append(component)

    def reset(self) -> None:
        self.input.clear()
        if self.scheduler is not None:
            self.scheduler.init_state()

    def run(self, input: str | lr.ChatDocument | None = None) -> List[lr.ChatDocument]:
        """Push team input to listening components, run the scheduler, notify.

        Raises:
            ValueError: if no scheduler has been added. The check happens
                before `input` is recorded, so a failed call does not leave
                stale input behind to be replayed by the next run.
        """
        if self.scheduler is None:
            raise ValueError(
                f"Team '{self.name}' has no scheduler. Call add_scheduler() first."
            )
        if input is not None:
            self.input.add(input)
        input_str = self.input.get_context().content
        logger.warning(f"Running team {self.name}... on input = {input_str}")
        # push the input of self to each component that's a listener of self.
        n_pushed = 0
        for comp in self.components:
            if comp in self.listeners:
                comp.input.add(self.input.messages)
                n_pushed += 1
        if len(self.input.messages) > 0 and n_pushed == 0:
            logger.warning(
                """
                Warning: Team inputs not pushed to any components!
                You may not be able to run any components unless they have their
                own inputs. Make sure to set up component to listen to parent team
                if needed.
                """
            )
        # clear own input since we've pushed it to internal components
        self.input.clear()

        result = self.scheduler.run()
        if len(result) > 0:
            self._notify(result)
        result_value = result[0].content if len(result) > 0 else "null"
        logger.warning(f"Team {self.name} done: {result_value}")
        return result
|
290
|
-
|
291
|
-
|
292
|
-
class DummyAgent:
    """Trivial stand-in agent that echoes its input, prefixed by its name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def process(self, data: str) -> str:
        return "{} processed: {}".format(self.name, data)
|
298
|
-
|
299
|
-
|
300
|
-
class TaskComponent(Component):
    """Component wrapping an `lr.Task`, run on its accumulated input context."""

    def __init__(self, task: lr.Task) -> None:
        super().__init__()
        self.task = task
        self.name = task.agent.config.name

    def run(self, input: str | lr.ChatDocument | None = None) -> List[lr.ChatDocument]:
        """Run the task on the merged input context and notify listeners.

        Returns [] without running the task when the context is empty.
        Otherwise the consumed input is cleared, and a truthy task result is
        returned as a one-element list (else []).
        """
        if input is not None:
            self.input.add(input)
        context_doc = self.input.get_context()
        if context_doc.content == "":
            return []
        logger.warning(f"Running task {self.name} on input = {context_doc.content}")
        answer = self.task.run(context_doc)
        if answer:
            logger.warning(f"Task {self.name} done: {answer.content}")
            outputs = [answer]
            self._notify(outputs)
        else:
            logger.warning(f"Task {self.name} done: null")
            outputs = []
        self.input.clear()  # clear own input since we just consumed it!
        return outputs
|
321
|
-
|
322
|
-
|
323
|
-
def make_task(name: str, sys: str = "") -> TaskComponent:
    """Build a TaskComponent around a MockLM-backed ChatAgent named `name`.

    The mock LLM answers with `sum_fn`. When `sys` is non-empty, it is used as
    the agent's system message; previously it was accepted but silently
    ignored. An empty `sys` keeps the ChatAgentConfig default.
    """
    llm_config = MockLMConfig(response_fn=sum_fn)
    config = lr.ChatAgentConfig(
        llm=llm_config,
        name=name,
    )
    if sys:
        config.system_message = sys
    agent = lr.ChatAgent(config)
    # set as single_round since there are no Tools
    task = lr.Task(agent, interactive=False, single_round=True)
    return TaskComponent(task)
|
334
|
-
|
335
|
-
|
336
|
-
if __name__ == "__main__":
    # Create agents, tasks
    t1 = make_task("a1")
    t2 = make_task("a2")
    t3 = make_task("a3")

    # done conditions for each team
    def team1_done_condition(team: Team, scheduler: Scheduler) -> bool:
        # done once a1 and a2 have each responded at least twice
        return (
            scheduler.responder_counts.get("a1", 0) >= 2
            and scheduler.responder_counts.get("a2", 0) >= 2
        )

    def team2_done_condition(team: Team, scheduler: Scheduler) -> bool:
        # done as soon as a3 has responded
        return "a3" in scheduler.responders

    def general_team_done_condition(team: Team, scheduler: Scheduler) -> bool:
        # Example: all components have responded at least once
        return len(set(scheduler.responders)) == len(team.components)

    # Create teams
    team1 = Team("T1", done_condition=team1_done_condition)
    team2 = Team("T2", done_condition=team2_done_condition)

    team = Team("Team", done_condition=general_team_done_condition)

    team1.add_scheduler(OrElseScheduler)
    team2.add_scheduler(OrElseScheduler)
    team.add_scheduler(OrElseScheduler)

    team.add([team1, team2])

    # Build hierarchy
    team1.add([t1, t2])
    team2.add(t3)

    # Set up listening
    # team2.listen(team1) # listens to team1 final result
    team1.listen(team)
    t1.listen(team1)
    t2.listen(t1)
    t1.listen(t2)
    # TODO should we forbid listening to a component OUTSIDE the team?

    # team2 listens to t1 and t2 (members of team1), and t3 listens to its
    # parent team2 => any input to team2 gets pushed to t3 when team2 runs
    team2.listen([t1, t2])
    t3.listen(team2)

    # TODO - we should either define which component of a team gets the teams inputs,
    # or explicitly add messages to a specific component of the team

    print("Running top-level team...")
    result = team.run("1")
    print(f"Final result: {[doc.content for doc in result]}")
|
390
|
-
|
391
|
-
##########
|