langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +70 -0
- langroid/agent/__init__.py +22 -0
- langroid/agent/base.py +120 -33
- langroid/agent/batch.py +134 -35
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +164 -100
- langroid/agent/chat_document.py +19 -2
- langroid/agent/openai_assistant.py +20 -10
- langroid/agent/special/__init__.py +33 -10
- langroid/agent/special/doc_chat_agent.py +521 -108
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +23 -7
- langroid/agent/special/retriever_agent.py +29 -174
- langroid/agent/special/sql/__init__.py +7 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +11 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +423 -114
- langroid/agent/tool_message.py +67 -10
- langroid/agent/tools/__init__.py +8 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +6 -24
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/cachedb/__init__.py +6 -0
- langroid/embedding_models/__init__.py +24 -0
- langroid/embedding_models/base.py +9 -1
- langroid/embedding_models/models.py +117 -17
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +22 -0
- langroid/language_models/azure_openai.py +47 -4
- langroid/language_models/base.py +26 -10
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_gpt.py +407 -121
- langroid/language_models/prompt_formatter/__init__.py +9 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +10 -9
- langroid/mytypes.py +10 -4
- langroid/parsing/__init__.py +33 -1
- langroid/parsing/document_parser.py +259 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +20 -7
- langroid/parsing/repo_loader.py +108 -46
- langroid/parsing/search.py +8 -0
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -13
- langroid/parsing/urls.py +18 -9
- langroid/parsing/utils.py +130 -9
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +7 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +10 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/configuration.py +0 -1
- langroid/utils/constants.py +4 -0
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +15 -2
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +446 -4
- langroid/utils/system.py +36 -1
- langroid/vector_store/__init__.py +34 -2
- langroid/vector_store/base.py +33 -2
- langroid/vector_store/chromadb.py +42 -13
- langroid/vector_store/lancedb.py +226 -60
- langroid/vector_store/meilisearch.py +7 -6
- langroid/vector_store/momento.py +3 -2
- langroid/vector_store/qdrantdb.py +82 -11
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
- langroid-0.1.219.dist-info/RECORD +127 -0
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.139.dist-info/RECORD +0 -103
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
- {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
langroid/embedding_models/protoc/embeddings_pb2.py
ADDED
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: embeddings.proto
+# Protobuf Python Version: 4.25.1
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x10\x65mbeddings.proto"K\n\x10\x45mbeddingRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\x05\x12\x0f\n\x07strings\x18\x03 \x03(\t"%\n\x0b\x42\x61tchEmbeds\x12\x16\n\x06\x65mbeds\x18\x01 \x03(\x0b\x32\x06.Embed"\x16\n\x05\x45mbed\x12\r\n\x05\x65mbed\x18\x01 \x03(\x02\x32\x37\n\tEmbedding\x12*\n\x05\x45mbed\x12\x11.EmbeddingRequest\x1a\x0c.BatchEmbeds"\x00\x62\x06proto3'
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embeddings_pb2", _globals)
+if _descriptor._USE_C_DESCRIPTORS == False:
+    DESCRIPTOR._options = None
+    _globals["_EMBEDDINGREQUEST"]._serialized_start = 20
+    _globals["_EMBEDDINGREQUEST"]._serialized_end = 95
+    _globals["_BATCHEMBEDS"]._serialized_start = 97
+    _globals["_BATCHEMBEDS"]._serialized_end = 134
+    _globals["_EMBED"]._serialized_start = 136
+    _globals["_EMBED"]._serialized_end = 158
+    _globals["_EMBEDDING"]._serialized_start = 160
+    _globals["_EMBEDDING"]._serialized_end = 215
+# @@protoc_insertion_point(module_scope)
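For reference, the serialized descriptor above encodes three messages (`EmbeddingRequest`, `BatchEmbeds`, `Embed`) and an `Embedding` service with one unary `Embed` RPC. A minimal sketch of using the generated classes (standard protobuf API; the values are placeholders, not taken from the package):

```python
# Sketch: round-tripping a request through the generated protobuf classes.
import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb2

req = embeddings_pb2.EmbeddingRequest(
    model_name="BAAI/bge-large-en-v1.5",  # placeholder model name
    batch_size=512,
    strings=["hello", "world"],
)
data = req.SerializeToString()  # wire-format bytes
same = embeddings_pb2.EmbeddingRequest.FromString(data)
assert list(same.strings) == ["hello", "world"]
```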
langroid/embedding_models/protoc/embeddings_pb2.pyi
ADDED
@@ -0,0 +1,50 @@
+from typing import (
+    ClassVar as _ClassVar,
+)
+from typing import (
+    Iterable as _Iterable,
+)
+from typing import (
+    Mapping as _Mapping,
+)
+from typing import (
+    Optional as _Optional,
+)
+from typing import (
+    Union as _Union,
+)
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf.internal import containers as _containers
+
+DESCRIPTOR: _descriptor.FileDescriptor
+
+class EmbeddingRequest(_message.Message):
+    __slots__ = ("model_name", "batch_size", "strings")
+    MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+    BATCH_SIZE_FIELD_NUMBER: _ClassVar[int]
+    STRINGS_FIELD_NUMBER: _ClassVar[int]
+    model_name: str
+    batch_size: int
+    strings: _containers.RepeatedScalarFieldContainer[str]
+    def __init__(
+        self,
+        model_name: _Optional[str] = ...,
+        batch_size: _Optional[int] = ...,
+        strings: _Optional[_Iterable[str]] = ...,
+    ) -> None: ...
+
+class BatchEmbeds(_message.Message):
+    __slots__ = ("embeds",)
+    EMBEDS_FIELD_NUMBER: _ClassVar[int]
+    embeds: _containers.RepeatedCompositeFieldContainer[Embed]
+    def __init__(
+        self, embeds: _Optional[_Iterable[_Union[Embed, _Mapping]]] = ...
+    ) -> None: ...
+
+class Embed(_message.Message):
+    __slots__ = ("embed",)
+    EMBED_FIELD_NUMBER: _ClassVar[int]
+    embed: _containers.RepeatedScalarFieldContainer[float]
+    def __init__(self, embed: _Optional[_Iterable[float]] = ...) -> None: ...
langroid/embedding_models/protoc/embeddings_pb2_grpc.py
ADDED
@@ -0,0 +1,79 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+import langroid.embedding_models.protoc.embeddings_pb2 as embeddings__pb2
+
+
+class EmbeddingStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.Embed = channel.unary_unary(
+            "/Embedding/Embed",
+            request_serializer=embeddings__pb2.EmbeddingRequest.SerializeToString,
+            response_deserializer=embeddings__pb2.BatchEmbeds.FromString,
+        )
+
+
+class EmbeddingServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def Embed(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Method not implemented!")
+        raise NotImplementedError("Method not implemented!")
+
+
+def add_EmbeddingServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+        "Embed": grpc.unary_unary_rpc_method_handler(
+            servicer.Embed,
+            request_deserializer=embeddings__pb2.EmbeddingRequest.FromString,
+            response_serializer=embeddings__pb2.BatchEmbeds.SerializeToString,
+        ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+        "Embedding", rpc_method_handlers
+    )
+    server.add_generic_rpc_handlers((generic_handler,))
+
+
+# This class is part of an EXPERIMENTAL API.
+class Embedding(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def Embed(
+        request,
+        target,
+        options=(),
+        channel_credentials=None,
+        call_credentials=None,
+        insecure=False,
+        compression=None,
+        wait_for_ready=None,
+        timeout=None,
+        metadata=None,
+    ):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            "/Embedding/Embed",
+            embeddings__pb2.EmbeddingRequest.SerializeToString,
+            embeddings__pb2.BatchEmbeds.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+        )
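For orientation, here is a minimal sketch of wiring these generated helpers into a blocking server; the zero-vector servicer is a hypothetical placeholder (the real servicer lives in `remote_embeds.py`, shown next, using `grpc.aio`):

```python
# Hypothetical standalone server built on the generated registration helper.
from concurrent import futures

import grpc

import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb2
import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_grpc


class ZeroEmbeddingServicer(embeddings_grpc.EmbeddingServicer):
    def Embed(self, request, context):
        # Placeholder logic: one zero vector per input string.
        embeds = [embeddings_pb2.Embed(embed=[0.0]) for _ in request.strings]
        return embeddings_pb2.BatchEmbeds(embeds=embeds)


server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
embeddings_grpc.add_EmbeddingServicer_to_server(ZeroEmbeddingServicer(), server)
server.add_insecure_port("localhost:50052")
server.start()
server.wait_for_termination()
```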
langroid/embedding_models/remote_embeds.py
ADDED
@@ -0,0 +1,153 @@
+"""
+If run as a script, starts an RPC server which handles remote
+embedding requests:
+
+For example:
+python3 -m langroid.embedding_models.remote_embeds --port `port`
+
+where `port` is the port at which the service is exposed. Currently,
+supports insecure connections only, and this should NOT be exposed to
+the internet.
+"""
+
+import atexit
+import subprocess
+import time
+from typing import Callable, Optional
+
+import grpc
+from fire import Fire
+
+import langroid.embedding_models.models as em
+import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb
+import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_grpc
+from langroid.mytypes import Embeddings
+
+
+class RemoteEmbeddingRPCs(embeddings_grpc.EmbeddingServicer):
+    def __init__(
+        self,
+        model_name: str,
+        batch_size: int,
+        data_parallel: bool,
+        device: Optional[str],
+        devices: Optional[list[str]],
+    ):
+        super().__init__()
+        self.embedding_fn = em.SentenceTransformerEmbeddings(
+            em.SentenceTransformerEmbeddingsConfig(
+                model_name=model_name,
+                batch_size=batch_size,
+                data_parallel=data_parallel,
+                device=device,
+                devices=devices,
+            )
+        ).embedding_fn()
+
+    def Embed(
+        self, request: embeddings_pb.EmbeddingRequest, _: grpc.RpcContext
+    ) -> embeddings_pb.BatchEmbeds:
+        embeds = self.embedding_fn(list(request.strings))
+
+        embeds_pb = [embeddings_pb.Embed(embed=e) for e in embeds]
+
+        return embeddings_pb.BatchEmbeds(embeds=embeds_pb)
+
+
+class RemoteEmbeddingsConfig(em.SentenceTransformerEmbeddingsConfig):
+    api_base: str = "localhost"
+    port: int = 50052
+    # The below are used only when waiting for server creation
+    poll_delay: float = 0.01
+    max_retries: int = 1000
+
+
+class RemoteEmbeddings(em.SentenceTransformerEmbeddings):
+    def __init__(self, config: RemoteEmbeddingsConfig = RemoteEmbeddingsConfig()):
+        super().__init__(config)
+        self.config: RemoteEmbeddingsConfig = config
+        self.have_started_server: bool = False
+
+    def embedding_fn(self) -> Callable[[list[str]], Embeddings]:
+        def fn(texts: list[str]) -> Embeddings:
+            url = f"{self.config.api_base}:{self.config.port}"
+            with grpc.insecure_channel(url) as channel:
+                stub = embeddings_grpc.EmbeddingStub(channel)  # type: ignore
+                response = stub.Embed(
+                    embeddings_pb.EmbeddingRequest(
+                        strings=texts,
+                    )
+                )
+
+                return [list(emb.embed) for emb in response.embeds]
+
+        def with_handling(texts: list[str]) -> Embeddings:
+            # In local mode, start the server if it has not already
+            # been started
+            if self.config.api_base == "localhost" and not self.have_started_server:
+                try:
+                    return fn(texts)
+                # Occurs when the server hasn't been started
+                except grpc.RpcError:
+                    self.have_started_server = True
+                    # Start the server
+                    proc = subprocess.Popen(
+                        [
+                            "python3",
+                            __file__,
+                            "--bind_address_base",
+                            self.config.api_base,
+                            "--port",
+                            str(self.config.port),
+                            "--batch_size",
+                            str(self.config.batch_size),
+                            "--model_name",
+                            self.config.model_name,
+                        ],
+                    )
+
+                    atexit.register(lambda: proc.terminate())
+
+                    for _ in range(self.config.max_retries - 1):
+                        try:
+                            return fn(texts)
+                        except grpc.RpcError:
+                            time.sleep(self.config.poll_delay)
+
+            # The remote is not local or we have exhausted retries
+            # We should now raise an error if the server is not accessible
+            return fn(texts)
+
+        return with_handling
+
+
+async def serve(
+    bind_address_base: str = "localhost",
+    port: int = 50052,
+    batch_size: int = 512,
+    data_parallel: bool = False,
+    device: Optional[str] = None,
+    devices: Optional[list[str]] = None,
+    model_name: str = "BAAI/bge-large-en-v1.5",
+) -> None:
+    """Starts the RPC server."""
+    server = grpc.aio.server()
+    embeddings_grpc.add_EmbeddingServicer_to_server(
+        RemoteEmbeddingRPCs(
+            model_name=model_name,
+            batch_size=batch_size,
+            data_parallel=data_parallel,
+            device=device,
+            devices=devices,
+        ),
+        server,
+    )  # type: ignore
+    url = f"{bind_address_base}:{port}"
+    server.add_insecure_port(url)
+    await server.start()
+    print(f"Embedding server started, listening on {url}")
+    await server.wait_for_termination()
+
+
+if __name__ == "__main__":
+    Fire(serve)
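Typical client-side use is just configuring and calling the embedding function; a sketch (assumes `sentence-transformers` and `fire` are installed; on `localhost` the module spawns the server process on first call, as in `with_handling` above):

```python
# Usage sketch for RemoteEmbeddings; the server auto-starts locally on the
# first RPC failure, then calls are retried until it is up.
from langroid.embedding_models.remote_embeds import (
    RemoteEmbeddings,
    RemoteEmbeddingsConfig,
)

embedder = RemoteEmbeddings(RemoteEmbeddingsConfig(port=50052))
embed_fn = embedder.embedding_fn()
vectors = embed_fn(["the quick brown fox", "jumps over the lazy dog"])
print(len(vectors), len(vectors[0]))  # 2 vectors; dimension depends on the model
```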
langroid/language_models/__init__.py
CHANGED
@@ -21,3 +21,25 @@ from .openai_gpt import (
     OpenAIGPT,
 )
 from .azure_openai import AzureConfig, AzureGPT
+
+__all__ = [
+    "utils",
+    "config",
+    "base",
+    "openai_gpt",
+    "azure_openai",
+    "prompt_formatter",
+    "LLMConfig",
+    "LLMMessage",
+    "LLMFunctionCall",
+    "LLMFunctionSpec",
+    "Role",
+    "LLMTokenUsage",
+    "LLMResponse",
+    "OpenAIChatModel",
+    "OpenAICompletionModel",
+    "OpenAIGPTConfig",
+    "OpenAIGPT",
+    "AzureConfig",
+    "AzureGPT",
+]
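One practical consequence: the listed names become the subpackage's declared public surface and survive star-imports. A quick sketch:

```python
# Sketch: names exported via __all__ above are importable from the subpackage.
from langroid.language_models import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig())  # assumes OPENAI_API_KEY is set in the env
```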
langroid/language_models/azure_openai.py
CHANGED
@@ -22,6 +22,9 @@ class AzureConfig(OpenAIGPTConfig):
             chose for your deployment when you deployed a model.
         model_name (str): can be set in the ``.env`` file as ``AZURE_GPT_MODEL_NAME``
             and should be based on the model name chosen during setup.
+        model_version (str): can be set in the ``.env`` file as
+            ``AZURE_OPENAI_MODEL_VERSION`` and should be based on the model name
+            chosen during setup.
     """
 
     api_key: str = ""  # CAUTION: set this ONLY via env var AZURE_OPENAI_API_KEY
@@ -29,6 +32,7 @@
     api_version: str = "2023-05-15"
     deployment_name: str = ""
     model_name: str = ""
+    model_version: str = ""  # is used to determine the cost of using the model
     api_base: str = ""
 
     # all of the vars above can be set via env vars,
@@ -50,6 +54,7 @@ class AzureGPT(OpenAIGPT):
         api_base (str): Azure API base url
         api_version (str): Azure API version
         model_name (str): the name of gpt model in your deployment
+        model_version (str): the version of gpt model in your deployment
     """
 
     def __init__(self, config: AzureConfig):
@@ -89,10 +94,7 @@
         # set the chat model to be the same as the model_name
         # This corresponds to the gpt model you chose for your deployment
         # when you deployed a model
-        if "35-turbo" in self.config.model_name:
-            self.config.chat_model = OpenAIChatModel.GPT3_5_TURBO
-        else:
-            self.config.chat_model = OpenAIChatModel.GPT4
+        self.set_chat_model()
 
         self.client = AzureOpenAI(
             api_key=self.config.api_key,
@@ -107,3 +109,44 @@
             azure_deployment=self.config.deployment_name,
             timeout=Timeout(self.config.timeout),
         )
+
+    def set_chat_model(self) -> None:
+        """
+        Sets the chat model configuration based on the model name specified in the
+        ``.env``. This function checks the `model_name` in the configuration and sets
+        the appropriate chat model in the `config.chat_model`. It supports handling for
+        '35-turbo' and 'gpt-4' models. For 'gpt-4', it further delegates the handling
+        to `handle_gpt4_model` method. If the model name does not match any predefined
+        models, it defaults to `OpenAIChatModel.GPT4`.
+        """
+        MODEL_35_TURBO = "35-turbo"
+        MODEL_GPT4 = "gpt-4"
+
+        if self.config.model_name == MODEL_35_TURBO:
+            self.config.chat_model = OpenAIChatModel.GPT3_5_TURBO
+        elif self.config.model_name == MODEL_GPT4:
+            self.handle_gpt4_model()
+        else:
+            self.config.chat_model = OpenAIChatModel.GPT4
+
+    def handle_gpt4_model(self) -> None:
+        """
+        Handles the setting of the GPT-4 model in the configuration.
+        This function checks the `model_version` in the configuration.
+        If the version is not set, it raises a ValueError indicating that the model
+        version needs to be specified in the ``.env`` file.
+        It sets `OpenAIChatModel.GPT4_TURBO` if the version is
+        '1106-Preview', otherwise, it defaults to setting `OpenAIChatModel.GPT4`.
+        """
+        VERSION_1106_PREVIEW = "1106-Preview"
+
+        if self.config.model_version == "":
+            raise ValueError(
+                "AZURE_OPENAI_MODEL_VERSION not set in .env file. "
+                "Please set it to the chat model version used in your deployment."
+            )
+
+        if self.config.model_version == VERSION_1106_PREVIEW:
+            self.config.chat_model = OpenAIChatModel.GPT4_TURBO
+        else:
+            self.config.chat_model = OpenAIChatModel.GPT4
langroid/language_models/base.py
CHANGED
@@ -3,18 +3,18 @@ import asyncio
 import json
 import logging
 from abc import ABC, abstractmethod
+from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import aiohttp
-from pydantic import BaseModel, BaseSettings
+from pydantic import BaseModel, BaseSettings, Field
 
 from langroid.cachedb.momento_cachedb import MomentoCacheConfig
 from langroid.cachedb.redis_cachedb import RedisCacheConfig
-from langroid.language_models.config import Llama2FormatterConfig, PromptFormatterConfig
 from langroid.mytypes import Document
 from langroid.parsing.agent_chats import parse_message
-from langroid.parsing.json import top_level_json_field
+from langroid.parsing.parse_json import top_level_json_field
 from langroid.prompts.dialog import collate_chat_history
 from langroid.prompts.templates import (
     EXTRACTION_PROMPT_GPT4,
@@ -26,16 +26,21 @@ from langroid.utils.output.printing import show_if_debug
 logger = logging.getLogger(__name__)
 
 
+def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
+    pass
+
+
 class LLMConfig(BaseSettings):
     type: str = "openai"
+    streamer: Optional[Callable[[Any], None]] = noop_fn
     api_base: str | None = None
-    formatter: None | PromptFormatterConfig = None
+    formatter: None | str = None
     timeout: int = 20  # timeout for API requests
     chat_model: str = ""
     completion_model: str = ""
     temperature: float = 0.0
-    chat_context_length: int =
-    completion_context_length: int =
+    chat_context_length: int = 8000
+    completion_context_length: int = 8000
     max_output_tokens: int = 1024  # generate at most this many tokens
     # if input length + max_output_tokens > context length of model,
     # we will try shortening requested output
@@ -72,8 +77,11 @@ class LLMFunctionCall(BaseModel):
         fun_args_str = message["arguments"]
         # sometimes may be malformed with invalid indents,
         # so we try to be safe by removing newlines.
-        fun_args_str = fun_args_str.replace("\n", "").strip()
-        fun_args = ast.literal_eval(fun_args_str)
+        if fun_args_str is not None:
+            fun_args_str = fun_args_str.replace("\n", "").strip()
+            fun_args = ast.literal_eval(fun_args_str)
+        else:
+            fun_args = None
         fun_call.arguments = fun_args
 
         return fun_call
@@ -135,6 +143,7 @@ class LLMMessage(BaseModel):
     tool_id: str = ""  # used by OpenAIAssistant
     content: str
     function_call: Optional[LLMFunctionCall] = None
+    timestamp: datetime = Field(default_factory=datetime.utcnow)
 
     def api_dict(self) -> Dict[str, Any]:
         """
@@ -157,6 +166,7 @@
                 dict_no_none["function_call"]["arguments"]
             )
         dict_no_none.pop("tool_id", None)
+        dict_no_none.pop("timestamp", None)
         return dict_no_none
 
     def __str__(self) -> str:
@@ -179,6 +189,12 @@ class LLMResponse(BaseModel):
     usage: Optional[LLMTokenUsage]
     cached: bool = False
 
+    def __str__(self) -> str:
+        if self.function_call is not None:
+            return str(self.function_call)
+        else:
+            return self.message
+
 def to_LLMMessage(self) -> LLMMessage:
         content = self.message
         role = Role.ASSISTANT if self.function_call is None else Role.FUNCTION
@@ -245,7 +261,7 @@ class LanguageModel(ABC):
     # usage cost by model, accumulates here
     usage_cost_dict: Dict[str, LLMTokenUsage] = {}
 
-    def __init__(self, config: LLMConfig):
+    def __init__(self, config: LLMConfig = LLMConfig()):
         self.config = config
 
     @staticmethod
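
A small but load-bearing detail of the `timestamp` addition: it is populated automatically but stripped by `api_dict()`, so nothing extra is sent to the OpenAI API. A quick sketch (`Role.USER` assumed from the `Role` enum referenced in the diff):

```python
# Sketch: timestamp is auto-set on creation, and api_dict() pops it before
# the message payload goes to the API (per the hunk above).
from langroid.language_models.base import LLMMessage, Role

msg = LLMMessage(role=Role.USER, content="hello")
print(msg.timestamp)  # set via Field(default_factory=datetime.utcnow)
assert "timestamp" not in msg.api_dict()
```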
|