langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/embedding_models/remote_embeds.py
ADDED
@@ -0,0 +1,153 @@
+"""
+If run as a script, starts an RPC server which handles remote
+embedding requests:
+
+For example:
+python3 -m langroid.embedding_models.remote_embeds --port `port`
+
+where `port` is the port at which the service is exposed. Currently,
+supports insecure connections only, and this should NOT be exposed to
+the internet.
+"""
+
+import atexit
+import subprocess
+import time
+from typing import Callable, Optional
+
+import grpc
+from fire import Fire
+
+import langroid.embedding_models.models as em
+import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb
+import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_grpc
+from langroid.mytypes import Embeddings
+
+
+class RemoteEmbeddingRPCs(embeddings_grpc.EmbeddingServicer):
+    def __init__(
+        self,
+        model_name: str,
+        batch_size: int,
+        data_parallel: bool,
+        device: Optional[str],
+        devices: Optional[list[str]],
+    ):
+        super().__init__()
+        self.embedding_fn = em.SentenceTransformerEmbeddings(
+            em.SentenceTransformerEmbeddingsConfig(
+                model_name=model_name,
+                batch_size=batch_size,
+                data_parallel=data_parallel,
+                device=device,
+                devices=devices,
+            )
+        ).embedding_fn()
+
+    def Embed(
+        self, request: embeddings_pb.EmbeddingRequest, _: grpc.RpcContext
+    ) -> embeddings_pb.BatchEmbeds:
+        embeds = self.embedding_fn(list(request.strings))
+
+        embeds_pb = [embeddings_pb.Embed(embed=e) for e in embeds]
+
+        return embeddings_pb.BatchEmbeds(embeds=embeds_pb)
+
+
+class RemoteEmbeddingsConfig(em.SentenceTransformerEmbeddingsConfig):
+    api_base: str = "localhost"
+    port: int = 50052
+    # The below are used only when waiting for server creation
+    poll_delay: float = 0.01
+    max_retries: int = 1000
+
+
+class RemoteEmbeddings(em.SentenceTransformerEmbeddings):
+    def __init__(self, config: RemoteEmbeddingsConfig = RemoteEmbeddingsConfig()):
+        super().__init__(config)
+        self.config: RemoteEmbeddingsConfig = config
+        self.have_started_server: bool = False
+
+    def embedding_fn(self) -> Callable[[list[str]], Embeddings]:
+        def fn(texts: list[str]) -> Embeddings:
+            url = f"{self.config.api_base}:{self.config.port}"
+            with grpc.insecure_channel(url) as channel:
+                stub = embeddings_grpc.EmbeddingStub(channel)  # type: ignore
+                response = stub.Embed(
+                    embeddings_pb.EmbeddingRequest(
+                        strings=texts,
+                    )
+                )
+
+                return [list(emb.embed) for emb in response.embeds]
+
+        def with_handling(texts: list[str]) -> Embeddings:
+            # In local mode, start the server if it has not already
+            # been started
+            if self.config.api_base == "localhost" and not self.have_started_server:
+                try:
+                    return fn(texts)
+                # Occurs when the server hasn't been started
+                except grpc.RpcError:
+                    self.have_started_server = True
+                    # Start the server
+                    proc = subprocess.Popen(
+                        [
+                            "python3",
+                            __file__,
+                            "--bind_address_base",
+                            self.config.api_base,
+                            "--port",
+                            str(self.config.port),
+                            "--batch_size",
+                            str(self.config.batch_size),
+                            "--model_name",
+                            self.config.model_name,
+                        ],
+                    )
+
+                    atexit.register(lambda: proc.terminate())
+
+                    for _ in range(self.config.max_retries - 1):
+                        try:
+                            return fn(texts)
+                        except grpc.RpcError:
+                            time.sleep(self.config.poll_delay)
+
+            # The remote is not local or we have exhausted retries
+            # We should now raise an error if the server is not accessible
+            return fn(texts)
+
+        return with_handling
+
+
+async def serve(
+    bind_address_base: str = "localhost",
+    port: int = 50052,
+    batch_size: int = 512,
+    data_parallel: bool = False,
+    device: Optional[str] = None,
+    devices: Optional[list[str]] = None,
+    model_name: str = "BAAI/bge-large-en-v1.5",
+) -> None:
+    """Starts the RPC server."""
+    server = grpc.aio.server()
+    embeddings_grpc.add_EmbeddingServicer_to_server(
+        RemoteEmbeddingRPCs(
+            model_name=model_name,
+            batch_size=batch_size,
+            data_parallel=data_parallel,
+            device=device,
+            devices=devices,
+        ),
+        server,
+    )  # type: ignore
+    url = f"{bind_address_base}:{port}"
+    server.add_insecure_port(url)
+    await server.start()
+    print(f"Embedding server started, listening on {url}")
+    await server.wait_for_termination()
+
+
+if __name__ == "__main__":
+    Fire(serve)
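For orientation, here is a minimal usage sketch of the new module (an illustration, not part of the diff). It assumes the sentence-transformers and grpc extras are installed, and relies on the auto-start behavior above: when `api_base` is "localhost", `RemoteEmbeddings` spawns the server on the first failed call and retries.

# Sketch only: exercising the new RemoteEmbeddings client shown above.
from langroid.embedding_models.remote_embeds import (
    RemoteEmbeddings,
    RemoteEmbeddingsConfig,
)

cfg = RemoteEmbeddingsConfig(
    model_name="BAAI/bge-large-en-v1.5",  # same default as serve()
    port=50052,  # default insecure gRPC port; keep it off the internet
)
embed = RemoteEmbeddings(cfg).embedding_fn()
# The first call may start the local server and poll until it is up.
vecs = embed(["hello world", "another sentence"])
print(len(vecs), len(vecs[0]))  # 2 vectors of the model's dimension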
langroid/language_models/__init__.py
ADDED
@@ -0,0 +1,45 @@
+from . import utils
+from . import config
+from . import base
+from . import openai_gpt
+from . import azure_openai
+from . import prompt_formatter
+
+from .base import (
+    LLMConfig,
+    LLMMessage,
+    LLMFunctionCall,
+    LLMFunctionSpec,
+    Role,
+    LLMTokenUsage,
+    LLMResponse,
+)
+from .openai_gpt import (
+    OpenAIChatModel,
+    OpenAICompletionModel,
+    OpenAIGPTConfig,
+    OpenAIGPT,
+)
+from .azure_openai import AzureConfig, AzureGPT
+
+__all__ = [
+    "utils",
+    "config",
+    "base",
+    "openai_gpt",
+    "azure_openai",
+    "prompt_formatter",
+    "LLMConfig",
+    "LLMMessage",
+    "LLMFunctionCall",
+    "LLMFunctionSpec",
+    "Role",
+    "LLMTokenUsage",
+    "LLMResponse",
+    "OpenAIChatModel",
+    "OpenAICompletionModel",
+    "OpenAIGPTConfig",
+    "OpenAIGPT",
+    "AzureConfig",
+    "AzureGPT",
+]
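The practical effect of this new `__init__.py` is that the core LLM types become importable from the subpackage root. A quick sketch (illustrative, not from the diff; the printed dict is indicative):

# Sketch: top-level imports enabled by the re-exports above.
from langroid.language_models import LLMMessage, OpenAIGPT, OpenAIGPTConfig, Role

msg = LLMMessage(role=Role.ASSISTANT, content="hello")
print(msg.api_dict())  # e.g. {'role': 'assistant', 'content': 'hello'}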
langroid/language_models/azure_openai.py
CHANGED
@@ -1,7 +1,6 @@
-import os
-
-import openai
 from dotenv import load_dotenv
+from httpx import Timeout
+from openai import AsyncAzureOpenAI, AzureOpenAI
 
 from langroid.language_models.openai_gpt import (
     OpenAIChatModel,
@@ -23,12 +22,26 @@ class AzureConfig(OpenAIGPTConfig):
            chose for your deployment when you deployed a model.
        model_name (str): can be set in the ``.env`` file as ``AZURE_GPT_MODEL_NAME``
            and should be based on the model name chosen during setup.
+       model_version (str): can be set in the ``.env`` file as
+           ``AZURE_OPENAI_MODEL_VERSION`` and should be based on the model name
+           chosen during setup.
     """
 
+    api_key: str = ""  # CAUTION: set this ONLY via env var AZURE_OPENAI_API_KEY
     type: str = "azure"
     api_version: str = "2023-05-15"
     deployment_name: str = ""
     model_name: str = ""
+    model_version: str = ""  # is used to determine the cost of using the model
+    api_base: str = ""
+
+    # all of the vars above can be set via env vars,
+    # by upper-casing the name and prefixing with `env_prefix`, e.g.
+    # AZURE_OPENAI_API_VERSION=2023-05-15
+    # This is either done in the .env file, or via an explicit
+    # `export AZURE_OPENAI_API_VERSION=...`
+    class Config:
+        env_prefix = "AZURE_OPENAI_"
 
 
 class AzureGPT(OpenAIGPT):
@@ -41,59 +54,99 @@ class AzureGPT(OpenAIGPT):
        api_base (str): Azure API base url
        api_version (str): Azure API version
        model_name (str): the name of gpt model in your deployment
+       model_version (str): the version of gpt model in your deployment
    """
 
    def __init__(self, config: AzureConfig):
+        # This will auto-populate config values from .env file
+        load_dotenv()
        super().__init__(config)
        self.config: AzureConfig = config
-        self.
-        openai.api_type = self.api_type
-        load_dotenv()
-        self.api_key = os.getenv("AZURE_API_KEY", "")
-        if self.api_key == "":
+        if self.config.api_key == "":
            raise ValueError(
                """
-
+                AZURE_OPENAI_API_KEY not set in .env file,
                please set it to your Azure API key."""
            )
 
-        self.api_base
-        if self.api_base == "":
+        if self.config.api_base == "":
            raise ValueError(
                """
                AZURE_OPENAI_API_BASE not set in .env file,
                please set it to your Azure API key."""
            )
-        # we don't need this for ``api_key`` because it's handled inside
-        # ``openai_gpt.py`` methods before invoking chat/completion calls
-        else:
-            openai.api_base = self.api_base
 
-        self.
-        os.getenv("AZURE_OPENAI_API_VERSION", "") or config.api_version
-        )
-        openai.api_version = self.api_version
-
-        self.deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "")
-        if self.deployment_name == "":
+        if self.config.deployment_name == "":
            raise ValueError(
                """
                AZURE_OPENAI_DEPLOYMENT_NAME not set in .env file,
                please set it to your Azure openai deployment name."""
            )
+        self.deployment_name = self.config.deployment_name
 
-        self.model_name
-        if self.model_name == "":
+        if self.config.model_name == "":
            raise ValueError(
                """
-
-                please set it to chat model name in
+                AZURE_OPENAI_MODEL_NAME not set in .env file,
+                please set it to chat model name in your deployment."""
            )
 
        # set the chat model to be the same as the model_name
        # This corresponds to the gpt model you chose for your deployment
        # when you deployed a model
-
+        self.set_chat_model()
+
+        self.client = AzureOpenAI(
+            api_key=self.config.api_key,
+            azure_endpoint=self.config.api_base,
+            api_version=self.config.api_version,
+            azure_deployment=self.config.deployment_name,
+        )
+        self.async_client = AsyncAzureOpenAI(
+            api_key=self.config.api_key,
+            azure_endpoint=self.config.api_base,
+            api_version=self.config.api_version,
+            azure_deployment=self.config.deployment_name,
+            timeout=Timeout(self.config.timeout),
+        )
+
+    def set_chat_model(self) -> None:
+        """
+        Sets the chat model configuration based on the model name specified in the
+        ``.env``. This function checks the `model_name` in the configuration and sets
+        the appropriate chat model in the `config.chat_model`. It supports handling for
+        '35-turbo' and 'gpt-4' models. For 'gpt-4', it further delegates the handling
+        to `handle_gpt4_model` method. If the model name does not match any predefined
+        models, it defaults to `OpenAIChatModel.GPT4`.
+        """
+        MODEL_35_TURBO = "35-turbo"
+        MODEL_GPT4 = "gpt-4"
+
+        if self.config.model_name == MODEL_35_TURBO:
            self.config.chat_model = OpenAIChatModel.GPT3_5_TURBO
+        elif self.config.model_name == MODEL_GPT4:
+            self.handle_gpt4_model()
+        else:
+            self.config.chat_model = OpenAIChatModel.GPT4
+
+    def handle_gpt4_model(self) -> None:
+        """
+        Handles the setting of the GPT-4 model in the configuration.
+        This function checks the `model_version` in the configuration.
+        If the version is not set, it raises a ValueError indicating that the model
+        version needs to be specified in the ``.env`` file.
+        It sets `OpenAIChatModel.GPT4_TURBO` if the version is
+        '1106-Preview', otherwise, it defaults to setting `OpenAIChatModel.GPT4`.
+        """
+        VERSION_1106_PREVIEW = "1106-Preview"
+
+        if self.config.model_version == "":
+            raise ValueError(
+                "AZURE_OPENAI_MODEL_VERSION not set in .env file. "
+                "Please set it to the chat model version used in your deployment."
+            )
+
+        if self.config.model_version == VERSION_1106_PREVIEW:
+            self.config.chat_model = OpenAIChatModel.GPT4_TURBO
        else:
            self.config.chat_model = OpenAIChatModel.GPT4
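Given the new `env_prefix = "AZURE_OPENAI_"`, configuration now flows entirely through environment variables (or a `.env` file). A hypothetical setup sketch with placeholder values, using only the variable names that appear in the diff above:

# Sketch: placeholder env values AzureConfig reads via env_prefix="AZURE_OPENAI_".
import os

os.environ["AZURE_OPENAI_API_KEY"] = "<your-key>"
os.environ["AZURE_OPENAI_API_BASE"] = "https://<resource>.openai.azure.com"
os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "<deployment>"
os.environ["AZURE_OPENAI_MODEL_NAME"] = "gpt-4"
os.environ["AZURE_OPENAI_MODEL_VERSION"] = "1106-Preview"  # selects GPT4_TURBO above

from langroid.language_models.azure_openai import AzureConfig, AzureGPT

llm = AzureGPT(AzureConfig())  # fields auto-populate from the environment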
langroid/language_models/base.py
CHANGED
@@ -1,19 +1,20 @@
+import ast
 import asyncio
 import json
 import logging
 from abc import ABC, abstractmethod
+from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import aiohttp
-from pydantic import BaseModel, BaseSettings
+from pydantic import BaseModel, BaseSettings, Field
 
 from langroid.cachedb.momento_cachedb import MomentoCacheConfig
 from langroid.cachedb.redis_cachedb import RedisCacheConfig
-from langroid.language_models.config import Llama2FormatterConfig, PromptFormatterConfig
 from langroid.mytypes import Document
 from langroid.parsing.agent_chats import parse_message
-from langroid.parsing.
+from langroid.parsing.parse_json import top_level_json_field
 from langroid.prompts.dialog import collate_chat_history
 from langroid.prompts.templates import (
     EXTRACTION_PROMPT_GPT4,
@@ -25,15 +26,21 @@ from langroid.utils.output.printing import show_if_debug
 logger = logging.getLogger(__name__)
 
 
+def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
+    pass
+
+
 class LLMConfig(BaseSettings):
     type: str = "openai"
-
+    streamer: Optional[Callable[[Any], None]] = noop_fn
+    api_base: str | None = None
+    formatter: None | str = None
     timeout: int = 20  # timeout for API requests
     chat_model: str = ""
     completion_model: str = ""
     temperature: float = 0.0
-    chat_context_length: int =
-    completion_context_length: int =
+    chat_context_length: int = 8000
+    completion_context_length: int = 8000
     max_output_tokens: int = 1024  # generate at most this many tokens
     # if input length + max_output_tokens > context length of model,
     # we will try shortening requested output
@@ -59,6 +66,26 @@ class LLMFunctionCall(BaseModel):
     to: str = ""  # intended recipient
     arguments: Optional[Dict[str, Any]] = None
 
+    @staticmethod
+    def from_dict(message: Dict[str, Any]) -> "LLMFunctionCall":
+        """
+        Initialize from dictionary.
+        Args:
+            d: dictionary containing fields to initialize
+        """
+        fun_call = LLMFunctionCall(name=message["name"])
+        fun_args_str = message["arguments"]
+        # sometimes may be malformed with invalid indents,
+        # so we try to be safe by removing newlines.
+        if fun_args_str is not None:
+            fun_args_str = fun_args_str.replace("\n", "").strip()
+            fun_args = ast.literal_eval(fun_args_str)
+        else:
+            fun_args = None
+        fun_call.arguments = fun_args
+
+        return fun_call
+
     def __str__(self) -> str:
         return "FUNC: " + json.dumps(self.dict(), indent=2)
 
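To make the new `from_dict` concrete, here is a sketch with a hypothetical function-call dict (the `arguments` string is newline-stripped, then parsed with `ast.literal_eval`, exactly as in the hunk above):

# Sketch: parsing an OpenAI-style function-call dict via the new from_dict.
from langroid.language_models.base import LLMFunctionCall

raw = {
    "name": "get_weather",  # hypothetical function name
    "arguments": "{'city': 'Tokyo',\n 'units': 'C'}",  # stray newline gets stripped
}
fc = LLMFunctionCall.from_dict(raw)
assert fc.arguments == {"city": "Tokyo", "units": "C"}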
@@ -79,6 +106,20 @@ class LLMTokenUsage(BaseModel):
     prompt_tokens: int = 0
     completion_tokens: int = 0
     cost: float = 0.0
+    calls: int = 0  # how many API calls
+
+    def reset(self) -> None:
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self.cost = 0.0
+        self.calls = 0
+
+    def __str__(self) -> str:
+        return (
+            f"Tokens = "
+            f"(prompt {self.prompt_tokens}, completion {self.completion_tokens}), "
+            f"Cost={self.cost}, Calls={self.calls}"
+        )
 
     @property
     def total_tokens(self) -> int:
@@ -99,12 +140,16 @@ class LLMMessage(BaseModel):
 
     role: Role
     name: Optional[str] = None
+    tool_id: str = ""  # used by OpenAIAssistant
     content: str
     function_call: Optional[LLMFunctionCall] = None
+    timestamp: datetime = Field(default_factory=datetime.utcnow)
 
     def api_dict(self) -> Dict[str, Any]:
         """
         Convert to dictionary for API request.
+        DROP the tool_id, since it is only for use in the Assistant API,
+        not the completion API.
         Returns:
             dict: dictionary representation of LLM message
         """
@@ -120,6 +165,8 @@ class LLMMessage(BaseModel):
             dict_no_none["function_call"]["arguments"] = json.dumps(
                 dict_no_none["function_call"]["arguments"]
             )
+        dict_no_none.pop("tool_id", None)
+        dict_no_none.pop("timestamp", None)
         return dict_no_none
 
     def __str__(self) -> str:
@@ -137,10 +184,17 @@ class LLMResponse(BaseModel):
     """
 
     message: str
+    tool_id: str = ""  # used by OpenAIAssistant
     function_call: Optional[LLMFunctionCall] = None
     usage: Optional[LLMTokenUsage]
     cached: bool = False
 
+    def __str__(self) -> str:
+        if self.function_call is not None:
+            return str(self.function_call)
+        else:
+            return self.message
+
     def to_LLMMessage(self) -> LLMMessage:
         content = self.message
         role = Role.ASSISTANT if self.function_call is None else Role.FUNCTION
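The new `calls` counter, `reset()`, and `__str__` on `LLMTokenUsage` make per-model accounting directly printable; a small sketch:

# Sketch: the extended LLMTokenUsage accounting fields in action.
from langroid.language_models.base import LLMTokenUsage

u = LLMTokenUsage(prompt_tokens=120, completion_tokens=30, cost=0.0045, calls=1)
print(u)  # Tokens = (prompt 120, completion 30), Cost=0.0045, Calls=1
u.reset()  # zeroes all four counters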
@@ -204,7 +258,10 @@ class LanguageModel(ABC):
     Abstract base class for language models.
     """
 
-
+    # usage cost by model, accumulates here
+    usage_cost_dict: Dict[str, LLMTokenUsage] = {}
+
+    def __init__(self, config: LLMConfig = LLMConfig()):
         self.config = config
 
     @staticmethod
@@ -215,6 +272,16 @@ class LanguageModel(ABC):
             config: configuration for language model
         Returns: instance of language model
         """
+        if type(config) is LLMConfig:
+            raise ValueError(
+                """
+                Cannot create a Language Model object from LLMConfig.
+                Please specify a specific subclass of LLMConfig e.g.,
+                OpenAIGPTConfig. If you are creating a ChatAgent from
+                a ChatAgentConfig, please specify the `llm` field of this config
+                as a specific subclass of LLMConfig, e.g., OpenAIGPTConfig.
+                """
+            )
         from langroid.language_models.azure_openai import AzureGPT
         from langroid.language_models.openai_gpt import OpenAIGPT
 
@@ -311,18 +378,18 @@ class LanguageModel(ABC):
         pass
 
     @abstractmethod
-    def generate(self, prompt: str, max_tokens: int) -> LLMResponse:
+    def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
         pass
 
     @abstractmethod
-    async def agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
+    async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
         pass
 
     @abstractmethod
     def chat(
         self,
         messages: Union[str, List[LLMMessage]],
-        max_tokens: int,
+        max_tokens: int = 200,
         functions: Optional[List[LLMFunctionSpec]] = None,
         function_call: str | Dict[str, str] = "auto",
     ) -> LLMResponse:
@@ -332,7 +399,7 @@ class LanguageModel(ABC):
     async def achat(
         self,
         messages: Union[str, List[LLMMessage]],
-        max_tokens: int,
+        max_tokens: int = 200,
         functions: Optional[List[LLMFunctionSpec]] = None,
         function_call: str | Dict[str, str] = "auto",
     ) -> LLMResponse:
@@ -350,6 +417,44 @@ class LanguageModel(ABC):
     def chat_cost(self) -> Tuple[float, float]:
         return self.config.chat_cost_per_1k_tokens
 
+    def reset_usage_cost(self) -> None:
+        for mdl in [self.config.chat_model, self.config.completion_model]:
+            if mdl is None:
+                return
+            if mdl not in self.usage_cost_dict:
+                self.usage_cost_dict[mdl] = LLMTokenUsage()
+            counter = self.usage_cost_dict[mdl]
+            counter.reset()
+
+    def update_usage_cost(
+        self, chat: bool, prompts: int, completions: int, cost: float
+    ) -> None:
+        """
+        Update usage cost for this LLM.
+        Args:
+            chat (bool): whether to update for chat or completion model
+            prompts (int): number of tokens used for prompts
+            completions (int): number of tokens used for completions
+            cost (float): total token cost in USD
+        """
+        mdl = self.config.chat_model if chat else self.config.completion_model
+        if mdl is None:
+            return
+        if mdl not in self.usage_cost_dict:
+            self.usage_cost_dict[mdl] = LLMTokenUsage()
+        counter = self.usage_cost_dict[mdl]
+        counter.prompt_tokens += prompts
+        counter.completion_tokens += completions
+        counter.cost += cost
+        counter.calls += 1
+
+    @classmethod
+    def usage_cost_summary(cls) -> str:
+        s = ""
+        for model, counter in cls.usage_cost_dict.items():
+            s += f"{model}: {counter}\n"
+        return s
+
     def followup_to_standalone(
         self, chat_history: List[Tuple[str, str]], question: str
     ) -> str:
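Because `usage_cost_dict` is a class attribute, costs accumulate per model name across all `LanguageModel` instances; `update_usage_cost` is invoked internally after each API call, and `usage_cost_summary()` reports the totals. A sketch (assuming an `OpenAIGPT` instance can be constructed, e.g. with an API key in the environment):

# Sketch: class-level accumulation of usage costs across calls.
from langroid.language_models.base import LanguageModel
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig())  # any concrete LanguageModel subclass
llm.update_usage_cost(chat=True, prompts=120, completions=30, cost=0.0045)
llm.update_usage_cost(chat=True, prompts=80, completions=20, cost=0.0030)
# Totals are keyed by the configured chat model, e.g.:
# gpt-4: Tokens = (prompt 200, completion 50), Cost=0.0075, Calls=2
print(LanguageModel.usage_cost_summary())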
|