langroid 0.1.234__py3-none-any.whl → 0.1.236__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +2 -0
- langroid/agent/batch.py +1 -0
- langroid/agent/chat_document.py +19 -0
- langroid/agent/openai_assistant.py +16 -13
- langroid/agent/special/table_chat_agent.py +0 -1
- langroid/agent/task.py +96 -13
- langroid/cachedb/redis_cachedb.py +1 -1
- langroid/language_models/openai_gpt.py +37 -12
- langroid/utils/constants.py +1 -1
- {langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/METADATA +15 -1
- {langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/RECORD +13 -13
- {langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/LICENSE +0 -0
- {langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/WHEEL +0 -0
langroid/__init__.py
CHANGED
@@ -27,6 +27,7 @@ from .agent.batch import (
 )

 from .agent.chat_document import (
+    StatusCode,
     ChatDocument,
     ChatDocMetaData,
 )
@@ -77,6 +78,7 @@ __all__ = [
     "AgentConfig",
     "ChatAgent",
     "ChatAgentConfig",
+    "StatusCode",
     "ChatDocument",
     "ChatDocMetaData",
     "Task",
langroid/agent/batch.py
CHANGED
@@ -53,6 +53,7 @@ def run_batch_task_gen(
         message (Optional[str]): optionally overrides the console status messages
         handle_exceptions: bool: Whether to replace exceptions with outputs of None
         max_cost: float: maximum cost to run the task (default 0.0 for unlimited)
+        max_tokens: int: maximum token usage (in and out) (default 0 for unlimited)


     Returns:
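The `max_tokens` and `max_cost` knobs documented here mirror the ones added to `Task.run()` in the `task.py` changes further down. A rough usage sketch follows; the exact `run_batch_tasks` signature is an assumption based on this docstring and the 0.1.235 release note near the end of this diff:

```python
# Sketch only: signature assumed from the docstring above and the changelog below.
from langroid import ChatAgent, ChatAgentConfig, Task
from langroid.agent.batch import run_batch_tasks

task = Task(ChatAgent(ChatAgentConfig()))
answers = run_batch_tasks(
    task,
    ["question 1", "question 2", "question 3"],
    max_cost=0.50,     # stop a run once estimated cost (USD) exceeds this
    max_tokens=2000,   # stop a run once token usage (in + out) exceeds this
)
```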
langroid/agent/chat_document.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+from enum import Enum
 from typing import List, Optional, Union

 from pydantic import BaseModel, Extra
@@ -23,6 +24,23 @@ class ChatDocAttachment(BaseModel):
         extra = Extra.allow


+class StatusCode(str, Enum):
+    """Codes meant to be returned by task.run(). Some are not used yet."""
+
+    OK = "OK"
+    ERROR = "ERROR"
+    DONE = "DONE"
+    STALLED = "STALLED"
+    INF_LOOP = "INF_LOOP"
+    KILL = "KILL"
+    MAX_TURNS = "MAX_TURNS"
+    MAX_COST = "MAX_COST"
+    MAX_TOKENS = "MAX_TOKENS"
+    TIMEOUT = "TIMEOUT"
+    NO_ANSWER = "NO_ANSWER"
+    USER_QUIT = "USER_QUIT"
+
+
 class ChatDocMetaData(DocMetaData):
     parent: Optional["ChatDocument"] = None
     sender: Entity
@@ -35,6 +53,7 @@ class ChatDocMetaData(DocMetaData):
     usage: Optional[LLMTokenUsage]
     cached: bool = False
     displayed: bool = False
+    status: Optional[StatusCode] = None


 class ChatDocLoggerFields(BaseModel):
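Together with the `task.py` changes below, the new `status` field on `ChatDocMetaData` lets callers inspect why a run ended. A minimal sketch, assuming a configured agent and API key (the agent/task construction here is illustrative, not part of this diff):

```python
# Minimal sketch: check the new StatusCode on a finished run.
# Agent/task construction is illustrative; an LLM API key is assumed.
from langroid import ChatAgent, ChatAgentConfig, StatusCode, Task

task = Task(ChatAgent(ChatAgentConfig()))
result = task.run("What is 2 + 2?", turns=3)

if result is not None and result.metadata.status == StatusCode.MAX_TURNS:
    print("Run ended because the turn limit was reached.")
```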
langroid/agent/openai_assistant.py
CHANGED
@@ -7,6 +7,7 @@ import time
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type, cast, no_type_check

+import openai
 from openai.types.beta import Assistant, Thread
 from openai.types.beta.assistant_update_params import (
     ToolResources,
@@ -99,12 +100,14 @@ class OpenAIAssistant(ChatAgent):
         super().__init__(config)
         self.config: OpenAIAssistantConfig = config
         self.llm: OpenAIGPT = OpenAIGPT(self.config.llm)
+        if not isinstance(self.llm.client, openai.OpenAI):
+            raise ValueError("Client must be OpenAI")
         # handles for various entities and methods
-        self.client = self.
-        self.runs = self.
-        self.threads = self.
-        self.thread_messages = self.
-        self.assistants = self.
+        self.client: openai.OpenAI = self.llm.client
+        self.runs = self.client.beta.threads.runs
+        self.threads = self.client.beta.threads
+        self.thread_messages = self.client.beta.threads.messages
+        self.assistants = self.client.beta.assistants
         # which tool_ids are awaiting output submissions
         self.pending_tool_ids: List[str] = []
         self.cached_tool_ids: List[str] = []
@@ -208,14 +211,14 @@ class OpenAIAssistant(ChatAgent):

     def _cache_thread_key(self) -> str:
         """Key to use for caching or retrieving thread id"""
-        org = self.
+        org = self.client.organization or ""
         uid = generate_user_id(org)
         name = self.config.name
         return "Thread:" + name + ":" + uid

     def _cache_assistant_key(self) -> str:
         """Key to use for caching or retrieving assistant id"""
-        org = self.
+        org = self.client.organization or ""
         uid = generate_user_id(org)
         name = self.config.name
         return "Assistant:" + name + ":" + uid
@@ -317,7 +320,7 @@ class OpenAIAssistant(ChatAgent):
         cached = self._cache_thread_lookup()
         if cached is not None:
             if self.config.use_cached_thread:
-                self.thread = self.
+                self.thread = self.client.beta.threads.retrieve(thread_id=cached)
             else:
                 logger.warning(
                     f"""
@@ -326,7 +329,7 @@ class OpenAIAssistant(ChatAgent):
                    """
                 )
                 try:
-                    self.
+                    self.client.beta.threads.delete(thread_id=cached)
                 except Exception:
                     logger.warning(
                         f"""
@@ -337,7 +340,7 @@ class OpenAIAssistant(ChatAgent):
         if self.thread is None:
             if self.assistant is None:
                 raise ValueError("Assistant is None")
-            self.thread = self.
+            self.thread = self.client.beta.threads.create()
             hash_key_str = (
                 (self.assistant.instructions or "")
                 + str(self.config.use_tools)
@@ -371,7 +374,7 @@ class OpenAIAssistant(ChatAgent):
         cached = self._cache_assistant_lookup()
         if cached is not None:
             if self.config.use_cached_assistant:
-                self.assistant = self.
+                self.assistant = self.client.beta.assistants.retrieve(
                     assistant_id=cached
                 )
             else:
@@ -382,7 +385,7 @@ class OpenAIAssistant(ChatAgent):
                    """
                 )
                 try:
-                    self.
+                    self.client.beta.assistants.delete(assistant_id=cached)
                 except Exception:
                     logger.warning(
                         f"""
@@ -391,7 +394,7 @@ class OpenAIAssistant(ChatAgent):
             )
             self.llm.cache.delete_keys([self._cache_assistant_key()])
         if self.assistant is None:
-            self.assistant = self.
+            self.assistant = self.client.beta.assistants.create(
                 name=self.config.name,
                 instructions=self.config.system_message,
                 tools=[],
langroid/agent/task.py
CHANGED
@@ -27,11 +27,20 @@ from langroid.agent.chat_document import (
     ChatDocLoggerFields,
     ChatDocMetaData,
     ChatDocument,
+    StatusCode,
 )
+from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
 from langroid.mytypes import Entity
 from langroid.parsing.parse_json import extract_top_level_json
 from langroid.utils.configuration import settings
-from langroid.utils.constants import
+from langroid.utils.constants import (
+    DONE,
+    NO_ANSWER,
+    PASS,
+    PASS_TO,
+    SEND_TO,
+    USER_QUIT_STRINGS,
+)
 from langroid.utils.logging import RichFileLogger, setup_file_logger

 logger = logging.getLogger(__name__)
@@ -73,6 +82,9 @@ class Task:
     the value of `result()`, which is the final result of the task.
     """

+    # class variable called `cache` that is a RedisCache object
+    cache: RedisCache = RedisCache(RedisCacheConfig(fake=False))
+
     def __init__(
         self,
         agent: Optional[Agent] = None,
@@ -141,7 +153,6 @@ class Task:
         """
         if agent is None:
             agent = ChatAgent()
-
         self.callbacks = SimpleNamespace(
             show_subtask_response=noop_fn,
             set_parent_agent=noop_fn,
@@ -172,6 +183,7 @@ class Task:
         agent.set_user_message(user_message)
         self.max_cost: float = 0
         self.max_tokens: int = 0
+        self.session_id: str = ""
         self.logger: None | RichFileLogger = None
         self.tsv_logger: None | logging.Logger = None
         self.color_log: bool = False if settings.notebook else True
@@ -285,6 +297,54 @@ class Task:
     def __str__(self) -> str:
         return f"{self.name}"

+    def _cache_session_store(self, key: str, value: str) -> None:
+        """
+        Cache a key-value pair for the current session.
+        E.g. key = "kill", value = "1"
+        """
+        try:
+            self.cache.store(f"{self.session_id}:{key}", value)
+        except Exception as e:
+            logging.error(f"Error in Task._cache_session_store: {e}")
+
+    def _cache_session_lookup(self, key: str) -> Dict[str, Any] | str | None:
+        """
+        Retrieve a value from the cache for the current session.
+        """
+        session_id_key = f"{self.session_id}:{key}"
+        try:
+            cached_val = self.cache.retrieve(session_id_key)
+        except Exception as e:
+            logging.error(f"Error in Task._cache_session_lookup: {e}")
+            return None
+        return cached_val
+
+    def _is_kill(self) -> bool:
+        """
+        Check if the current session is killed.
+        """
+        return self._cache_session_lookup("kill") == "1"
+
+    def _set_alive(self) -> None:
+        """
+        Initialize the kill status of the current session.
+        """
+        self._cache_session_store("kill", "0")
+
+    @classmethod
+    def kill_session(cls, session_id: str = "") -> None:
+        """
+        Kill the session with the given session_id.
+        """
+        session_id_kill_key = f"{session_id}:kill"
+        cls.cache.store(session_id_kill_key, "1")
+
+    def kill(self) -> None:
+        """
+        Kill the task run associated with the current session.
+        """
+        self._cache_session_store("kill", "1")
+
     @property
     def _level(self) -> int:
         if self.caller is None:
@@ -378,6 +438,7 @@ class Task:
         caller: None | Task = None,
         max_cost: float = 0,
         max_tokens: int = 0,
+        session_id: str = "",
     ) -> Optional[ChatDocument]:
         """Synchronous version of `run_async()`.
         See `run_async()` for details."""
@@ -385,6 +446,9 @@ class Task:
         self.n_stalled_steps = 0
         self.max_cost = max_cost
         self.max_tokens = max_tokens
+        self.session_id = session_id
+        self._set_alive()
+
         assert (
             msg is None or isinstance(msg, str) or isinstance(msg, ChatDocument)
         ), f"msg arg in Task.run() must be None, str, or ChatDocument, not {type(msg)}"
@@ -406,15 +470,19 @@ class Task:
         i = 0
         while True:
             self.step()
-
+            done, status = self.done()
+            if done:
                 if self._level == 0 and not settings.quiet:
                     print("[magenta]Bye, hope this was useful!")
                 break
             i += 1
             if turns > 0 and i >= turns:
+                status = StatusCode.MAX_TURNS
                 break

         final_result = self.result()
+        if final_result is not None:
+            final_result.metadata.status = status
         self._post_run_loop()
         return final_result

@@ -425,6 +493,7 @@ class Task:
         caller: None | Task = None,
         max_cost: float = 0,
         max_tokens: int = 0,
+        session_id: str = "",
     ) -> Optional[ChatDocument]:
         """
         Loop over `step()` until task is considered done or `turns` is reached.
@@ -443,6 +512,7 @@ class Task:
             caller (Task|None): the calling task, if any
             max_cost (float): max cost allowed for the task (default 0 -> no limit)
             max_tokens (int): max tokens allowed for the task (default 0 -> no limit)
+            session_id (str): session id for the task

         Returns:
             Optional[ChatDocument]: valid result of the task.
@@ -456,6 +526,9 @@ class Task:
         self.n_stalled_steps = 0
         self.max_cost = max_cost
         self.max_tokens = max_tokens
+        self.session_id = session_id
+        self._set_alive()
+
         if (
             isinstance(msg, ChatDocument)
             and msg.metadata.recipient != ""
@@ -473,15 +546,19 @@ class Task:
         i = 0
         while True:
             await self.step_async()
-
+            done, status = self.done()
+            if done:
                 if self._level == 0 and not settings.quiet:
                     print("[magenta]Bye, hope this was useful!")
                 break
             i += 1
             if turns > 0 and i >= turns:
+                status = StatusCode.MAX_TURNS
                 break

         final_result = self.result()
+        if final_result is not None:
+            final_result.metadata.status = status
         self._post_run_loop()
         return final_result

@@ -942,6 +1019,7 @@ class Task:
         recipient = result_msg.metadata.recipient if result_msg else None
         responder = result_msg.metadata.parent_responder if result_msg else None
         tool_ids = result_msg.metadata.tool_ids if result_msg else []
+        status = result_msg.metadata.status if result_msg else None

         # regardless of which entity actually produced the result,
         # when we return the result, we set entity to USER
@@ -954,6 +1032,7 @@ class Task:
                 source=Entity.USER,
                 sender=Entity.USER,
                 block=block,
+                status=status,
                 parent_responder=responder,
                 sender_name=self.name,
                 recipient=recipient,
@@ -1036,7 +1115,7 @@ class Task:

     def done(
         self, result: ChatDocument | None = None, r: Responder | None = None
-    ) -> bool:
+    ) -> Tuple[bool, StatusCode]:
         """
         Check if task is done. This is the default behavior.
         Derived classes can override this.
@@ -1046,26 +1125,29 @@ class Task:
                 Not used here, but could be used by derived classes.
         Returns:
             bool: True if task is done, False otherwise
+            StatusCode: status code indicating why task is done
         """
+        if self._is_kill():
+            return (True, StatusCode.KILL)
         result = result or self.pending_message
         user_quit = (
             result is not None
-            and result.content in
+            and result.content in USER_QUIT_STRINGS
             and result.metadata.sender == Entity.USER
         )
         if self._level == 0 and self.only_user_quits_root:
             # for top-level task, only user can quit out
-            return user_quit
+            return (user_quit, StatusCode.USER_QUIT if user_quit else StatusCode.OK)

         if self.is_done:
-            return True
+            return (True, StatusCode.DONE)

         if self.n_stalled_steps >= self.max_stalled_steps:
             # we are stuck, so bail to avoid infinite loop
             logger.warning(
                 f"Task {self.name} stuck for {self.max_stalled_steps} steps; exiting."
             )
-            return True
+            return (True, StatusCode.STALLED)

         if self.max_cost > 0 and self.agent.llm is not None:
             try:
@@ -1073,7 +1155,7 @@ class Task:
                 logger.warning(
                     f"Task {self.name} cost exceeded {self.max_cost}; exiting."
                 )
-                return True
+                return (True, StatusCode.MAX_COST)
             except Exception:
                 pass

@@ -1083,10 +1165,10 @@ class Task:
                 logger.warning(
                     f"Task {self.name} uses > {self.max_tokens} tokens; exiting."
                 )
-                return True
+                return (True, StatusCode.MAX_TOKENS)
             except Exception:
                 pass
-
+
         final = (
             # no valid response from any entity/agent in current turn
             result is None
             # An entity decided task is done
@@ -1103,6 +1185,7 @@ class Task:
             # )
             or user_quit
         )
+        return (final, StatusCode.OK)

     def valid(
         self,
@@ -1120,7 +1203,7 @@ class Task:

         # if task would be considered done given responder r's `result`,
         # then consider the result valid.
-        if result is not None and self.done(result, r):
+        if result is not None and self.done(result, r)[0]:
             return True
         return (
             result is not None
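The session plumbing above means a long-running task can be killed from outside its own call stack, keyed by `session_id`. A hedged sketch (assumes a reachable Redis instance, since `Task.cache` is a `RedisCache`; the threading setup and names are illustrative):

```python
# Hedged sketch of the kill mechanism added in this diff.
# Requires Redis (Task.cache is a RedisCache); threading here is illustrative.
import threading
import time

from langroid import ChatAgent, ChatAgentConfig, Task

task = Task(ChatAgent(ChatAgentConfig()))

# Run under an explicit session id so the run can be killed externally.
runner = threading.Thread(
    target=lambda: task.run("Summarize this corpus", session_id="sess-123")
)
runner.start()

time.sleep(5)
Task.kill_session("sess-123")  # or, with a handle on the task object: task.kill()
runner.join()
```

Once the kill flag is seen, `done()` returns `(True, StatusCode.KILL)` and the result's `metadata.status` is set accordingly.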
langroid/cachedb/redis_cachedb.py
CHANGED
@@ -109,7 +109,7 @@ class RedisCache(CacheDB):
             key (str): The key to retrieve the value for.

         Returns:
-            dict: The value associated with the key.
+            dict|str|None: The value associated with the key.
         """
         with self.redis_client() as client:  # type: ignore
             try:
langroid/language_models/openai_gpt.py
CHANGED
@@ -21,6 +21,7 @@ from typing import (
 )

 import openai
+from groq import AsyncGroq, Groq
 from httpx import Timeout
 from openai import AsyncOpenAI, OpenAI
 from pydantic import BaseModel
@@ -347,6 +348,9 @@ class OpenAIGPT(LanguageModel):
     Class for OpenAI LLMs
     """

+    client: OpenAI | Groq
+    async_client: AsyncOpenAI | AsyncGroq
+
     def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
         """
         Args:
@@ -448,18 +452,31 @@ class OpenAIGPT(LanguageModel):
             self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
         else:
             self.api_key = DUMMY_API_KEY
-
-
-
-
-
-
-
-
-
-
-
-
+
+        self.is_groq = self.config.chat_model.startswith("groq/")
+
+        if self.is_groq:
+            self.config.chat_model = self.config.chat_model.replace("groq/", "")
+            self.api_key = os.getenv("GROQ_API_KEY", DUMMY_API_KEY)
+            self.client = Groq(
+                api_key=self.api_key,
+            )
+            self.async_client = AsyncGroq(
+                api_key=self.api_key,
+            )
+        else:
+            self.client = OpenAI(
+                api_key=self.api_key,
+                base_url=self.api_base,
+                organization=self.config.organization,
+                timeout=Timeout(self.config.timeout),
+            )
+            self.async_client = AsyncOpenAI(
+                api_key=self.api_key,
+                organization=self.config.organization,
+                base_url=self.api_base,
+                timeout=Timeout(self.config.timeout),
+            )

         self.cache: MomentoCache | RedisCache
         if settings.cache_type == "momento":
@@ -855,6 +872,9 @@ class OpenAIGPT(LanguageModel):
         if self.config.use_chat_for_completion:
             return self.chat(messages=prompt, max_tokens=max_tokens)

+        if self.is_groq:
+            raise ValueError("Groq does not support pure completions")
+
         if settings.debug:
             print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

@@ -869,6 +889,7 @@ class OpenAIGPT(LanguageModel):
             else:
                 if self.config.litellm:
                     from litellm import completion as litellm_completion
+                assert isinstance(self.client, OpenAI)
                 completion_call = (
                     litellm_completion
                     if self.config.litellm
@@ -929,6 +950,9 @@ class OpenAIGPT(LanguageModel):
         if self.config.use_chat_for_completion:
             return await self.achat(messages=prompt, max_tokens=max_tokens)

+        if self.is_groq:
+            raise ValueError("Groq does not support pure completions")
+
         if settings.debug:
             print(f"[grey37]PROMPT: {escape(prompt)}[/grey37]")

@@ -948,6 +972,7 @@ class OpenAIGPT(LanguageModel):
                     from litellm import acompletion as litellm_acompletion
                 # TODO this may not work: text_completion is not async,
                 # and we didn't find an async version in litellm
+                assert isinstance(self.async_client, AsyncOpenAI)
                 acompletion_call = (
                     litellm_acompletion
                     if self.config.litellm
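With the routing above, a `groq/`-prefixed `chat_model` selects the Groq client instead of OpenAI. A small sketch, assuming `GROQ_API_KEY` is set in the environment (per the code above); the model name is taken from the 0.1.236 release note below:

```python
# Sketch: point the OpenAIGPT wrapper at a Groq-hosted model.
# Assumes GROQ_API_KEY is set in the environment.
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm_config = OpenAIGPTConfig(chat_model="groq/llama3-8b-8192")
llm = OpenAIGPT(llm_config)  # the "groq/" prefix routes to the Groq client
response = llm.chat(messages="Say hello in one word.", max_tokens=20)
```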
langroid/utils/constants.py
CHANGED
{langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langroid
-Version: 0.1.234
+Version: 0.1.236
 Summary: Harness LLMs with Multi-Agent Programming
 License: MIT
 Author: Prasad Chalasani
@@ -35,6 +35,7 @@ Requires-Dist: fakeredis (>=2.12.1,<3.0.0)
 Requires-Dist: fire (>=0.5.0,<0.6.0)
 Requires-Dist: flake8 (>=6.0.0,<7.0.0)
 Requires-Dist: google-api-python-client (>=2.95.0,<3.0.0)
+Requires-Dist: groq (>=0.5.0,<0.6.0)
 Requires-Dist: grpcio (>=1.62.1,<2.0.0)
 Requires-Dist: halo (>=0.0.31,<0.0.32)
 Requires-Dist: huggingface-hub (>=0.21.2,<0.22.0) ; extra == "transformers"
@@ -230,6 +231,19 @@ teacher_task.run()
 <details>
 <summary> <b>Click to expand</b></summary>

+- **Apr 2024:**
+  - **0.1.236**: Support for open LLMs hosted on Groq, e.g. specify
+    `chat_model="groq/llama3-8b-8192"`.
+    See [tutorial](https://langroid.github.io/langroid/tutorials/local-llm-setup/).
+  - **0.1.235**: `Task.run(), Task.run_async(), run_batch_tasks` have `max_cost`
+    and `max_tokens` params to exit when tokens or cost exceed a limit. The result
+    `ChatDocument.metadata` now includes a `status` field which is a code indicating a
+    task completion reason code. Also `task.run()` etc can be invoked with an explicit
+    `session_id` field which is used as a key to look up various settings in Redis cache.
+    Currently only used to look up "kill status" - this allows killing a running task, either by `task.kill()`
+    or by the classmethod `Task.kill_session(session_id)`.
+    For example usage, see the `test_task_kill` in [tests/main/test_task.py](https://github.com/langroid/langroid/blob/main/tests/main/test_task.py)
+
 - **Mar 2024:**
   - **0.1.216:** Improvements to allow concurrent runs of `DocChatAgent`, see the
     [`test_doc_chat_agent.py`](https://github.com/langroid/langroid/blob/main/tests/main/test_doc_chat_agent.py)
{langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
-langroid/__init__.py,sha256=
+langroid/__init__.py,sha256=zsYpGiAUsvyzZzjm964NUamsJImrXSJPVGz9a2jE_uY,1679
 langroid/agent/__init__.py,sha256=_D8dxnfdr92ch1CIrUkKjrB5HVvsQdn62b1Fb2kBxV8,785
 langroid/agent/base.py,sha256=jyGFmojrFuOy81lUkNsJlR6mLIOY6kOD20P9dhEcEuw,35059
-langroid/agent/batch.py,sha256=
+langroid/agent/batch.py,sha256=feRA_yRG768ElOQjrKEefcRv6Aefd_yY7qktuYUQDwc,10040
 langroid/agent/callbacks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 langroid/agent/callbacks/chainlit.py,sha256=aYuJ8M4VDHr5oymoXL2bpThM7p6P9L45fgJf3MLdkWo,20997
 langroid/agent/chat_agent.py,sha256=X5uVMm9qdw3j-FRf4hbN8k8ByaSdtQCTuU8olKE0sbs,38750
-langroid/agent/chat_document.py,sha256=
+langroid/agent/chat_document.py,sha256=NGr5FEWasPUQZ7cJnqrkVYYTi5fOqplSoCU-z5tTONA,8422
 langroid/agent/helpers.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 langroid/agent/junk,sha256=LxfuuW7Cijsg0szAzT81OjWWv1PMNI-6w_-DspVIO2s,339
-langroid/agent/openai_assistant.py,sha256=
+langroid/agent/openai_assistant.py,sha256=kIVDI4r-xGvplLU5s0nShPVHs6Jq-wOsfWE0kcMhAdQ,33056
 langroid/agent/special/__init__.py,sha256=NG0JkB5y4K0bgnd9Q9UIvFExun3uTfVOWEVLVymff1M,1207
 langroid/agent/special/doc_chat_agent.py,sha256=LwWNb_1s5n9rOk9OpOFPuuY1VnVX5DjzQmPwBanKRrM,53763
 langroid/agent/special/lance_doc_chat_agent.py,sha256=USp0U3eTaJzwF_3bdqE7CedSLbaqAi2tm-VzygcyLaA,10175
@@ -31,8 +31,8 @@ langroid/agent/special/sql/utils/description_extractors.py,sha256=RZ2R3DmASxB1ij
 langroid/agent/special/sql/utils/populate_metadata.py,sha256=x2OMKfmIBnJESBG3qKt6gvr3H3L4ZQcoxHfNdWfHjZs,2987
 langroid/agent/special/sql/utils/system_message.py,sha256=qKLHkvQWRQodTtPLPxr1GSLUYUFASZU8x-ybV67cB68,1885
 langroid/agent/special/sql/utils/tools.py,sha256=6uB2424SLtmapui9ggcEr0ZTiB6_dL1-JRGgN8RK9Js,1332
-langroid/agent/special/table_chat_agent.py,sha256=
-langroid/agent/task.py,sha256=
+langroid/agent/special/table_chat_agent.py,sha256=coEvEWL9UJJSeDu8JcOxR4qCyzH7HuTdre7-3pMfGjo,8785
+langroid/agent/task.py,sha256=b_d46txohISETxXJoWpmIX0hinvt1wjHbK08LZRBEz8,54020
 langroid/agent/tool_message.py,sha256=2kPsQUwi3ZzINTUNj10huKnZLjLp5SXmefacTHx8QDc,8304
 langroid/agent/tools/__init__.py,sha256=q-maq3k2BXhPAU99G0H6-j_ozoRvx15I1RFpPVicQIU,304
 langroid/agent/tools/duckduckgo_search_tool.py,sha256=mLGhlgs6pwbYZIwrOs9shfh1dMBVT4DtkR29pYL3cCQ,1900
@@ -47,7 +47,7 @@ langroid/agent_config.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 langroid/cachedb/__init__.py,sha256=ygx42MS7fvh2UwRMjukTk3dWBkzv_rACebTBRYa_MkU,148
 langroid/cachedb/base.py,sha256=tdIZmdDdDMW-wVkNQdi4vMQCHP718l9JM6cDhL6odf4,1229
 langroid/cachedb/momento_cachedb.py,sha256=IbaYG7HgG-G18GlWsYVnLv0r2e2S48z6sl8OlJOGUfc,2998
-langroid/cachedb/redis_cachedb.py,sha256=
+langroid/cachedb/redis_cachedb.py,sha256=5WrwgareXGboZeaCLkJ8MarqRMrrXl4_8o8aDrdrOCE,4993
 langroid/embedding_models/__init__.py,sha256=AJg2668ytmUyqYP0SGw-ZKz2ITi4YK7IAv2lfCjFfOg,714
 langroid/embedding_models/base.py,sha256=xY9QF01ilsMvaNH4JMDvkZgXY59AeYR4VAykgNd6Flg,1818
 langroid/embedding_models/clustering.py,sha256=tZWElUqXl9Etqla0FAa7og96iDKgjqWjucZR_Egtp-A,6684
@@ -62,7 +62,7 @@ langroid/language_models/azure_openai.py,sha256=ncRCbKooqLVOY-PWQUIo9C3yTuKEFbAw
 langroid/language_models/base.py,sha256=B6dX43ZR65mIvjD95W4RcfpT-WpmiuEcstR3eMrr56Y,21029
 langroid/language_models/config.py,sha256=5UF3DzO1a-Dfsc3vghE0XGq7g9t_xDsRCsuRiU4dgBg,366
 langroid/language_models/openai_assistants.py,sha256=9K-DEAL2aSWHeXj2hwCo2RAlK9_1oCPtqX2u1wISCj8,36
-langroid/language_models/openai_gpt.py,sha256=
+langroid/language_models/openai_gpt.py,sha256=BOZt2lOFViN3ct-jvfELRKeUkUaBOGhGxO7F6JQNCNY,50257
 langroid/language_models/prompt_formatter/__init__.py,sha256=9JXFF22QNMmbQV1q4nrIeQVTtA3Tx8tEZABLtLBdFyc,352
 langroid/language_models/prompt_formatter/base.py,sha256=eDS1sgRNZVnoajwV_ZIha6cba5Dt8xjgzdRbPITwx3Q,1221
 langroid/language_models/prompt_formatter/hf_formatter.py,sha256=TFL6ppmeQWnzr6CKQzRZFYY810zE1mr8DZnhw6i85ok,5217
@@ -98,7 +98,7 @@ langroid/utils/__init__.py,sha256=ARx5To4Hsv1K5QAzK4uUqdEoB_iq5HK797vae1AcMBI,30
 langroid/utils/algorithms/__init__.py,sha256=WylYoZymA0fnzpB4vrsH_0n7WsoLhmuZq8qxsOCjUpM,41
 langroid/utils/algorithms/graph.py,sha256=JbdpPnUOhw4-D6O7ou101JLA3xPCD0Lr3qaPoFCaRfo,2866
 langroid/utils/configuration.py,sha256=TiDZrQVeEthMFA4QY_HTgQaDCJwS4I5S-aR_taOdc00,3201
-langroid/utils/constants.py,sha256=
+langroid/utils/constants.py,sha256=Y_8p7CyLF5b3xsEV5O3wuutLHQCtegsaxWgr8yNTlIE,563
 langroid/utils/docker.py,sha256=kJQOLTgM0x9j9pgIIqp0dZNZCTvoUDhp6i8tYBq1Jr0,1105
 langroid/utils/globals.py,sha256=VkTHhlqSz86oOPq65sjul0XU8I52UNaFC5vwybMQ74w,1343
 langroid/utils/llms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -120,7 +120,7 @@ langroid/vector_store/meilisearch.py,sha256=d2huA9P-NoYRuAQ9ZeXJmMKr7ry8u90RUSR2
 langroid/vector_store/momento.py,sha256=9cui31TTrILid2KIzUpBkN2Ey3g_CZWOQVdaFsA4Ors,10045
 langroid/vector_store/qdrant_cloud.py,sha256=3im4Mip0QXLkR6wiqVsjV1QvhSElfxdFSuDKddBDQ-4,188
 langroid/vector_store/qdrantdb.py,sha256=foKRxRv0BBony6S4Vt0Vav9Rn9HMxZvcIh1cE7nosFE,13524
-langroid-0.1.
-langroid-0.1.
-langroid-0.1.
-langroid-0.1.
+langroid-0.1.236.dist-info/LICENSE,sha256=EgVbvA6VSYgUlvC3RvPKehSg7MFaxWDsFuzLOsPPfJg,1065
+langroid-0.1.236.dist-info/METADATA,sha256=jZ9zU6bW0HHFIwFgeUvlDp4VrpPvoYsOk0S6nAbvHNw,48866
+langroid-0.1.236.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+langroid-0.1.236.dist-info/RECORD,,
{langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/LICENSE
File without changes
{langroid-0.1.234.dist-info → langroid-0.1.236.dist-info}/WHEEL
File without changes