langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
langroid/language_models/base.py
@@ -0,0 +1,678 @@
+import json
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from enum import Enum
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)
+
+from langroid.cachedb.base import CacheDBConfig
+from langroid.cachedb.redis_cachedb import RedisCacheConfig
+from langroid.parsing.agent_chats import parse_message
+from langroid.parsing.parse_json import parse_imperfect_json, top_level_json_field
+from langroid.prompts.dialog import collate_chat_history
+from langroid.pydantic_v1 import BaseModel, BaseSettings, Field
+from langroid.utils.configuration import settings
+from langroid.utils.output.printing import show_if_debug
+
+logger = logging.getLogger(__name__)
+
+
+def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
+    pass
+
+
+async def async_noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
+    pass
+
+
+FunctionCallTypes = Literal["none", "auto"]
+ToolChoiceTypes = Literal["none", "auto", "required"]
+ToolTypes = Literal["function"]
+
+
+class LLMConfig(BaseSettings):
+    """
+    Common configuration for all language models.
+    """
+
+    type: str = "openai"
+    streamer: Optional[Callable[[Any], None]] = noop_fn
+    streamer_async: Optional[Callable[..., Awaitable[None]]] = async_noop_fn
+    api_base: str | None = None
+    formatter: None | str = None
+    timeout: int = 20  # timeout for API requests
+    chat_model: str = ""
+    completion_model: str = ""
+    temperature: float = 0.0
+    chat_context_length: int = 8000
+    async_stream_quiet: bool = True  # suppress streaming output in async mode?
+    completion_context_length: int = 8000
+    max_output_tokens: int = 1024  # generate at most this many tokens
+    # if input length + max_output_tokens > context length of model,
+    # we will try shortening requested output
+    min_output_tokens: int = 64
+    use_completion_for_chat: bool = False  # use completion model for chat?
+    # use chat model for completion? For OpenAI models, this MUST be set to True!
+    use_chat_for_completion: bool = True
+    stream: bool = True  # stream output from API?
+    cache_config: None | CacheDBConfig = RedisCacheConfig()
+
+    # Dict of model -> (input/prompt cost, output/completion cost)
+    chat_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
+    completion_cost_per_1k_tokens: Tuple[float, float] = (0.0, 0.0)
+
+
+class LLMFunctionCall(BaseModel):
+    """
+    Structure of LLM response indicating it "wants" to call a function.
+    Modeled after OpenAI spec for `function_call` field in ChatCompletion API.
+    """
+
+    name: str  # name of function to call
+    arguments: Optional[Dict[str, Any]] = None
+
+    @staticmethod
+    def from_dict(message: Dict[str, Any]) -> "LLMFunctionCall":
+        """
+        Initialize from dictionary.
+        Args:
+            message: dictionary containing fields to initialize
+        """
+        fun_call = LLMFunctionCall(name=message["name"])
+        fun_args_str = message["arguments"]
+        # sometimes may be malformed with invalid indents,
+        # so we try to be safe by removing newlines.
+        if fun_args_str is not None:
+            fun_args_str = fun_args_str.replace("\n", "").strip()
+            dict_or_list = parse_imperfect_json(fun_args_str)
+
+            if not isinstance(dict_or_list, dict):
+                raise ValueError(
+                    f"""
+                    Invalid function args: {fun_args_str}
+                    parsed as {dict_or_list},
+                    which is not a valid dict.
+                    """
+                )
+            fun_args = dict_or_list
+        else:
+            fun_args = None
+        fun_call.arguments = fun_args
+
+        return fun_call
+
+    def __str__(self) -> str:
+        return "FUNC: " + json.dumps(self.dict(), indent=2)
+
+
+class LLMFunctionSpec(BaseModel):
+    """
+    Description of a function available for the LLM to use.
+    To be used when calling the LLM `chat()` method with the `functions` parameter.
+    Modeled after OpenAI spec for `functions` fields in ChatCompletion API.
+    """
+
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+
+class OpenAIToolCall(BaseModel):
+    """
+    Represents a single tool call in a list of tool calls generated by OpenAI LLM API.
+    See https://platform.openai.com/docs/api-reference/chat/create
+
+    Attributes:
+        id: The id of the tool call.
+        type: The type of the tool call;
+            only "function" is currently possible (7/26/24).
+        function: The function call.
+    """
+
+    id: str | None = None
+    type: ToolTypes = "function"
+    function: LLMFunctionCall | None = None
+
+    @staticmethod
+    def from_dict(message: Dict[str, Any]) -> "OpenAIToolCall":
+        """
+        Initialize from dictionary.
+        Args:
+            message: dictionary containing fields to initialize
+        """
+        id = message["id"]
+        type = message["type"]
+        function = LLMFunctionCall.from_dict(message["function"])
+        return OpenAIToolCall(id=id, type=type, function=function)
+
+    def __str__(self) -> str:
+        if self.function is None:
+            return ""
+        return "OAI-TOOL: " + json.dumps(self.function.dict(), indent=2)
+
+
+class OpenAIToolSpec(BaseModel):
+    type: ToolTypes
+    strict: Optional[bool] = None
+    function: LLMFunctionSpec
+
+
+class OpenAIJsonSchemaSpec(BaseModel):
+    strict: Optional[bool] = None
+    function: LLMFunctionSpec
+
+    def to_dict(self) -> Dict[str, Any]:
+        json_schema: Dict[str, Any] = {
+            "name": self.function.name,
+            "description": self.function.description,
+            "schema": self.function.parameters,
+        }
+        if self.strict is not None:
+            json_schema["strict"] = self.strict
+
+        return {
+            "type": "json_schema",
+            "json_schema": json_schema,
+        }
+
+
+class LLMTokenUsage(BaseModel):
+    """
+    Usage of tokens by an LLM.
+    """
+
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    cost: float = 0.0
+    calls: int = 0  # how many API calls
+
+    def reset(self) -> None:
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self.cost = 0.0
+        self.calls = 0
+
+    def __str__(self) -> str:
+        return (
+            f"Tokens = "
+            f"(prompt {self.prompt_tokens}, completion {self.completion_tokens}), "
+            f"Cost={self.cost}, Calls={self.calls}"
+        )
+
+    @property
+    def total_tokens(self) -> int:
+        return self.prompt_tokens + self.completion_tokens
+
+
+class Role(str, Enum):
+    """
+    Possible roles for a message in a chat.
+    """
+
+    USER = "user"
+    SYSTEM = "system"
+    ASSISTANT = "assistant"
+    FUNCTION = "function"
+    TOOL = "tool"
+
+
+class LLMMessage(BaseModel):
+    """
+    Class representing an entry in the msg-history sent to the LLM API.
+    It could be one of these:
+    - a user message
+    - an LLM ("Assistant") response
+    - a fn-call or tool-call-list from an OpenAI-compatible LLM API response
+    - a result or results from executing a fn or tool-call(s)
+    """
+
+    role: Role
+    name: Optional[str] = None
+    tool_call_id: Optional[str] = None  # which OpenAI LLM tool this is a response to
+    tool_id: str = ""  # used by OpenAIAssistant
+    content: str
+    function_call: Optional[LLMFunctionCall] = None
+    tool_calls: Optional[List[OpenAIToolCall]] = None
+    timestamp: datetime = Field(default_factory=datetime.utcnow)
+    # link to corresponding chat document, for provenance/rewind purposes
+    chat_document_id: str = ""
+
+    def api_dict(self, has_system_role: bool = True) -> Dict[str, Any]:
+        """
+        Convert to dictionary for API request, keeping ONLY
+        the fields that are expected in an API call!
+        E.g., DROP the tool_id, since it is only for use in the Assistant API,
+        not the completion API.
+
+        Args:
+            has_system_role: whether the message has a system role (if not,
+                set to "user" role)
+        Returns:
+            dict: dictionary representation of LLM message
+        """
+        d = self.dict()
+        # if there is a key k = "role" with value "system", change to "user"
+        # in case has_system_role is False
+        if not has_system_role and "role" in d and d["role"] == "system":
+            d["role"] = "user"
+            if "content" in d:
+                d["content"] = "[ADDITIONAL SYSTEM MESSAGE:]\n\n" + d["content"]
+        # drop None values since API doesn't accept them
+        dict_no_none = {k: v for k, v in d.items() if v is not None}
+        if "name" in dict_no_none and dict_no_none["name"] == "":
+            # OpenAI API does not like empty name
+            del dict_no_none["name"]
+        if "function_call" in dict_no_none:
+            # arguments must be a string
+            if "arguments" in dict_no_none["function_call"]:
+                dict_no_none["function_call"]["arguments"] = json.dumps(
+                    dict_no_none["function_call"]["arguments"]
+                )
+        if "tool_calls" in dict_no_none:
+            # convert tool calls to API format
+            for tc in dict_no_none["tool_calls"]:
+                if "arguments" in tc["function"]:
+                    # arguments must be a string
+                    tc["function"]["arguments"] = json.dumps(
+                        tc["function"]["arguments"]
+                    )
+        # IMPORTANT! drop fields that are not expected in API call
+        dict_no_none.pop("tool_id", None)
+        dict_no_none.pop("timestamp", None)
+        dict_no_none.pop("chat_document_id", None)
+        return dict_no_none
+
+    def __str__(self) -> str:
+        if self.function_call is not None:
+            content = "FUNC: " + json.dumps(self.function_call.dict())
+        else:
+            content = self.content
+        name_str = f" ({self.name})" if self.name else ""
+        return f"{self.role} {name_str}: {content}"
+
+
+class LLMResponse(BaseModel):
+    """
+    Class representing response from LLM.
+    """
+
+    message: str
+    # TODO tool_id needs to generalize to multi-tool calls
+    tool_id: str = ""  # used by OpenAIAssistant
+    oai_tool_calls: Optional[List[OpenAIToolCall]] = None
+    function_call: Optional[LLMFunctionCall] = None
+    usage: Optional[LLMTokenUsage] = None
+    cached: bool = False
+
+    def __str__(self) -> str:
+        if self.function_call is not None:
+            return str(self.function_call)
+        elif self.oai_tool_calls:
+            return "\n".join(str(tc) for tc in self.oai_tool_calls)
+        else:
+            return self.message
+
+    def to_LLMMessage(self) -> LLMMessage:
+        """Convert LLM response to an LLMMessage, to be included in the
+        message-list sent to the API.
+        This is currently NOT used in any significant way in the library, and is only
+        provided as a utility to construct a message list for the API when directly
+        working with an LLM object.
+
+        In a `ChatAgent`, an LLM response is first converted to a ChatDocument,
+        which is in turn converted to an LLMMessage via `ChatDocument.to_LLMMessage()`
+        See `ChatAgent._prep_llm_messages()` and `ChatAgent.llm_response_messages`
+        """
+        return LLMMessage(
+            role=Role.ASSISTANT,
+            content=self.message,
+            name=None if self.function_call is None else self.function_call.name,
+            function_call=self.function_call,
+            tool_calls=self.oai_tool_calls,
+        )
+
+    def get_recipient_and_message(
+        self,
+    ) -> Tuple[str, str]:
+        """
+        If `message` or `function_call` of an LLM response contains an explicit
+        recipient name, return this recipient name and `message` stripped
+        of the recipient name if specified.
+
+        Two cases:
+        (a) `message` contains addressing string "TO: <name> <content>", or
+        (b) `message` is empty and function_call/tool_call with explicit `recipient`
+
+
+        Returns:
+            (str): name of recipient, which may be empty string if no recipient
+            (str): content of message
+
+        """
+
+        if self.function_call is not None:
+            # in this case we ignore message, since all information is in function_call
+            msg = ""
+            args = self.function_call.arguments
+            recipient = ""
+            if isinstance(args, dict):
+                recipient = args.get("recipient", "")
+            return recipient, msg
+        else:
+            msg = self.message
+            if self.oai_tool_calls is not None:
+                # get the first tool that has a recipient field, if any
+                for tc in self.oai_tool_calls:
+                    if tc.function is not None and tc.function.arguments is not None:
+                        recipient = tc.function.arguments.get(
+                            "recipient"
+                        )  # type: ignore
+                        if recipient is not None and recipient != "":
+                            return recipient, ""
+
+        # It's not a function or tool call, so continue looking to see
+        # if a recipient is specified in the message.
+
+        # First check if message contains "TO: <recipient> <content>"
+        recipient_name, content = parse_message(msg) if msg is not None else ("", "")
+        # check if there is a top level json that specifies 'recipient',
+        # and retain the entire message as content.
+        if recipient_name == "":
+            recipient_name = top_level_json_field(msg, "recipient") if msg else ""
+            content = msg
+        return recipient_name, content
+
+
+# Define an abstract base class for language models
+class LanguageModel(ABC):
+    """
+    Abstract base class for language models.
+    """
+
+    # usage cost by model, accumulates here
+    usage_cost_dict: Dict[str, LLMTokenUsage] = {}
+
+    def __init__(self, config: LLMConfig = LLMConfig()):
+        self.config = config
+
+    @staticmethod
+    def create(config: Optional[LLMConfig]) -> Optional["LanguageModel"]:
+        """
+        Create a language model.
+        Args:
+            config: configuration for language model
+        Returns: instance of language model
+        """
+        if type(config) is LLMConfig:
+            raise ValueError(
+                """
+                Cannot create a Language Model object from LLMConfig.
+                Please specify a specific subclass of LLMConfig e.g.,
+                OpenAIGPTConfig. If you are creating a ChatAgent from
+                a ChatAgentConfig, please specify the `llm` field of this config
+                as a specific subclass of LLMConfig, e.g., OpenAIGPTConfig.
+                """
+            )
+        from langroid.language_models.azure_openai import AzureGPT
+        from langroid.language_models.mock_lm import MockLM, MockLMConfig
+        from langroid.language_models.openai_gpt import OpenAIGPT
+
+        if config is None or config.type is None:
+            return None
+
+        if config.type == "mock":
+            return MockLM(cast(MockLMConfig, config))
+
+        openai: Union[Type[AzureGPT], Type[OpenAIGPT]]
+
+        if config.type == "azure":
+            openai = AzureGPT
+        else:
+            openai = OpenAIGPT
+        cls = dict(
+            openai=openai,
+        ).get(config.type, openai)
+        return cls(config)  # type: ignore
+
+    @staticmethod
+    def user_assistant_pairs(lst: List[str]) -> List[Tuple[str, str]]:
+        """
+        Given an even-length sequence of strings, split into a sequence of pairs
+
+        Args:
+            lst (List[str]): sequence of strings
+
+        Returns:
+            List[Tuple[str,str]]: sequence of pairs of strings
+        """
+        evens = lst[::2]
+        odds = lst[1::2]
+        return list(zip(evens, odds))
+
+    @staticmethod
+    def get_chat_history_components(
+        messages: List[LLMMessage],
+    ) -> Tuple[str, List[Tuple[str, str]], str]:
+        """
+        From the chat history, extract system prompt, user-assistant turns, and
+        final user msg.
+
+        Args:
+            messages (List[LLMMessage]): List of messages in the chat history
+
+        Returns:
+            Tuple[str, List[Tuple[str,str]], str]:
+                system prompt, user-assistant turns, final user msg
+
+        """
+        # Handle various degenerate cases
+        messages = [m for m in messages]  # copy
+        DUMMY_SYS_PROMPT = "You are a helpful assistant."
+        DUMMY_USER_PROMPT = "Follow the instructions above."
+        if len(messages) == 0 or messages[0].role != Role.SYSTEM:
+            logger.warning("No system msg, creating dummy system prompt")
+            messages.insert(0, LLMMessage(content=DUMMY_SYS_PROMPT, role=Role.SYSTEM))
+        system_prompt = messages[0].content
+
+        # now we have messages = [Sys,...]
+        if len(messages) == 1:
+            logger.warning(
+                "Got only system message in chat history, creating dummy user prompt"
+            )
+            messages.append(LLMMessage(content=DUMMY_USER_PROMPT, role=Role.USER))
+
+        # now we have messages = [Sys, msg, ...]
+
+        if messages[1].role != Role.USER:
+            messages.insert(1, LLMMessage(content=DUMMY_USER_PROMPT, role=Role.USER))
+
+        # now we have messages = [Sys, user, ...]
+        if messages[-1].role != Role.USER:
+            logger.warning(
+                "Last message in chat history is not a user message,"
+                " creating dummy user prompt"
+            )
+            messages.append(LLMMessage(content=DUMMY_USER_PROMPT, role=Role.USER))
+
+        # now we have messages = [Sys, user, ..., user]
+        # so we omit the first and last elements and make pairs of user-asst messages
+        conversation = [m.content for m in messages[1:-1]]
+        user_prompt = messages[-1].content
+        pairs = LanguageModel.user_assistant_pairs(conversation)
+        return system_prompt, pairs, user_prompt
+
+    @abstractmethod
+    def set_stream(self, stream: bool) -> bool:
+        """Enable or disable streaming output from API.
+        Return previous value of stream."""
+        pass
+
+    @abstractmethod
+    def get_stream(self) -> bool:
+        """Get streaming status"""
+        pass
+
+    @abstractmethod
+    def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+        pass
+
+    @abstractmethod
+    async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
+        pass
+
+    @abstractmethod
+    def chat(
+        self,
+        messages: Union[str, List[LLMMessage]],
+        max_tokens: int = 200,
+        tools: Optional[List[OpenAIToolSpec]] = None,
+        tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
+        functions: Optional[List[LLMFunctionSpec]] = None,
+        function_call: str | Dict[str, str] = "auto",
+        response_format: Optional[OpenAIJsonSchemaSpec] = None,
+    ) -> LLMResponse:
+        """
+        Get chat-completion response from LLM.
+
+        Args:
+            messages: message-history to send to the LLM
+            max_tokens: max tokens to generate
+            tools: tools available for the LLM to use in its response
+            tool_choice: tool call mode, one of "none", "auto", "required",
+                or a dict specifying a specific tool.
+            functions: functions available for LLM to call (deprecated)
+            function_call: function calling mode, "auto", "none", or a specific fn
+                (deprecated)
+        """
+
+        pass
+
+    @abstractmethod
+    async def achat(
+        self,
+        messages: Union[str, List[LLMMessage]],
+        max_tokens: int = 200,
+        tools: Optional[List[OpenAIToolSpec]] = None,
+        tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
+        functions: Optional[List[LLMFunctionSpec]] = None,
+        function_call: str | Dict[str, str] = "auto",
+        response_format: Optional[OpenAIJsonSchemaSpec] = None,
+    ) -> LLMResponse:
+        """Async version of `chat`. See `chat` for details."""
+        pass
+
+    def __call__(self, prompt: str, max_tokens: int) -> LLMResponse:
+        return self.generate(prompt, max_tokens)
+
+    def chat_context_length(self) -> int:
+        return self.config.chat_context_length
+
+    def completion_context_length(self) -> int:
+        return self.config.completion_context_length
+
+    def chat_cost(self) -> Tuple[float, float]:
+        return self.config.chat_cost_per_1k_tokens
+
+    def reset_usage_cost(self) -> None:
+        for mdl in [self.config.chat_model, self.config.completion_model]:
+            if mdl is None:
+                return
+            if mdl not in self.usage_cost_dict:
+                self.usage_cost_dict[mdl] = LLMTokenUsage()
+            counter = self.usage_cost_dict[mdl]
+            counter.reset()
+
+    def update_usage_cost(
+        self, chat: bool, prompts: int, completions: int, cost: float
+    ) -> None:
+        """
+        Update usage cost for this LLM.
+        Args:
+            chat (bool): whether to update for chat or completion model
+            prompts (int): number of tokens used for prompts
+            completions (int): number of tokens used for completions
+            cost (float): total token cost in USD
+        """
+        mdl = self.config.chat_model if chat else self.config.completion_model
+        if mdl is None:
+            return
+        if mdl not in self.usage_cost_dict:
+            self.usage_cost_dict[mdl] = LLMTokenUsage()
+        counter = self.usage_cost_dict[mdl]
+        counter.prompt_tokens += prompts
+        counter.completion_tokens += completions
+        counter.cost += cost
+        counter.calls += 1
+
+    @classmethod
+    def usage_cost_summary(cls) -> str:
+        s = ""
+        for model, counter in cls.usage_cost_dict.items():
+            s += f"{model}: {counter}\n"
+        return s
+
+    @classmethod
+    def tot_tokens_cost(cls) -> Tuple[int, float]:
+        """
+        Return total tokens used and total cost across all models.
+        """
+        total_tokens = 0
+        total_cost = 0.0
+        for counter in cls.usage_cost_dict.values():
+            total_tokens += counter.total_tokens
+            total_cost += counter.cost
+        return total_tokens, total_cost
+
+    def followup_to_standalone(
+        self, chat_history: List[Tuple[str, str]], question: str
+    ) -> str:
+        """
+        Given a chat history and a question, convert it to a standalone question.
+        Args:
+            chat_history: list of tuples of (question, answer)
+            question: follow-up question
+
+        Returns: standalone version of the question
+        """
+        history = collate_chat_history(chat_history)
+
+        prompt = f"""
+        Given the CHAT HISTORY below, and a follow-up QUESTION or SEARCH PHRASE,
+        rephrase the follow-up question/phrase as a STANDALONE QUESTION that
+        can be understood without the context of the chat history.
+
+        Chat history: {history}
+
+        Follow-up question: {question}
+        """.strip()
+        show_if_debug(prompt, "FOLLOWUP->STANDALONE-PROMPT= ")
+        standalone = self.generate(prompt=prompt, max_tokens=1024).message.strip()
+        show_if_debug(standalone, "FOLLOWUP->STANDALONE-RESPONSE= ")
+        return standalone
+
+
+class StreamingIfAllowed:
+    """Context to temporarily enable or disable streaming, if allowed globally via
+    `settings.stream`"""
+
+    def __init__(self, llm: LanguageModel, stream: bool = True):
+        self.llm = llm
+        self.stream = stream
+
+    def __enter__(self) -> None:
+        self.old_stream = self.llm.set_stream(settings.stream and self.stream)
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self.llm.set_stream(self.old_stream)
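For orientation, a minimal usage sketch of the `LanguageModel` API defined in this file. This is not part of the diff; it assumes a valid OPENAI_API_KEY in the environment, and the model name "gpt-4o" is illustrative only:

    from langroid.language_models.base import (
        LanguageModel, LLMMessage, Role, StreamingIfAllowed,
    )
    from langroid.language_models.openai_gpt import OpenAIGPTConfig

    # create() dispatches on config.type; a concrete LLMConfig subclass is required
    llm = LanguageModel.create(OpenAIGPTConfig(chat_model="gpt-4o"))
    assert llm is not None  # create() returns None only for a None/untyped config

    messages = [
        LLMMessage(role=Role.SYSTEM, content="You are a terse assistant."),
        LLMMessage(role=Role.USER, content="What is the capital of France?"),
    ]
    response = llm.chat(messages=messages, max_tokens=50)
    print(response.message)

    # temporarily disable streaming for one call; prior setting restored on exit
    with StreamingIfAllowed(llm, stream=False):
        print(llm.chat("Say OK.", max_tokens=5).message)

    print(LanguageModel.usage_cost_summary())  # accumulated per-model token/cost stats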
langroid/language_models/config.py
@@ -0,0 +1,18 @@
+from langroid.pydantic_v1 import BaseSettings
+
+
+class PromptFormatterConfig(BaseSettings):
+    type: str = "llama2"
+
+    class Config:
+        env_prefix = "FORMAT_"
+        case_sensitive = False
+
+
+class Llama2FormatterConfig(PromptFormatterConfig):
+    use_bos_eos: bool = False
+
+
+class HFPromptFormatterConfig(PromptFormatterConfig):
+    type: str = "hf"
+    model_name: str