langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (97)
  1. langroid/__init__.py +70 -0
  2. langroid/agent/__init__.py +22 -0
  3. langroid/agent/base.py +120 -33
  4. langroid/agent/batch.py +134 -35
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +608 -0
  7. langroid/agent/chat_agent.py +164 -100
  8. langroid/agent/chat_document.py +19 -2
  9. langroid/agent/openai_assistant.py +20 -10
  10. langroid/agent/special/__init__.py +33 -10
  11. langroid/agent/special/doc_chat_agent.py +521 -108
  12. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  13. langroid/agent/special/lance_rag/__init__.py +9 -0
  14. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  15. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  16. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  17. langroid/agent/special/lance_tools.py +44 -0
  18. langroid/agent/special/neo4j/__init__.py +0 -0
  19. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  20. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  21. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  22. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  23. langroid/agent/special/relevance_extractor_agent.py +23 -7
  24. langroid/agent/special/retriever_agent.py +29 -174
  25. langroid/agent/special/sql/__init__.py +7 -0
  26. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  27. langroid/agent/special/sql/utils/__init__.py +11 -0
  28. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  29. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  30. langroid/agent/special/table_chat_agent.py +43 -9
  31. langroid/agent/task.py +423 -114
  32. langroid/agent/tool_message.py +67 -10
  33. langroid/agent/tools/__init__.py +8 -0
  34. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  35. langroid/agent/tools/google_search_tool.py +11 -0
  36. langroid/agent/tools/metaphor_search_tool.py +67 -0
  37. langroid/agent/tools/recipient_tool.py +6 -24
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/cachedb/__init__.py +6 -0
  40. langroid/embedding_models/__init__.py +24 -0
  41. langroid/embedding_models/base.py +9 -1
  42. langroid/embedding_models/models.py +117 -17
  43. langroid/embedding_models/protoc/embeddings.proto +19 -0
  44. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  45. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  46. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  47. langroid/embedding_models/remote_embeds.py +153 -0
  48. langroid/language_models/__init__.py +22 -0
  49. langroid/language_models/azure_openai.py +47 -4
  50. langroid/language_models/base.py +26 -10
  51. langroid/language_models/config.py +5 -0
  52. langroid/language_models/openai_gpt.py +407 -121
  53. langroid/language_models/prompt_formatter/__init__.py +9 -0
  54. langroid/language_models/prompt_formatter/base.py +4 -6
  55. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  56. langroid/language_models/utils.py +10 -9
  57. langroid/mytypes.py +10 -4
  58. langroid/parsing/__init__.py +33 -1
  59. langroid/parsing/document_parser.py +259 -63
  60. langroid/parsing/image_text.py +32 -0
  61. langroid/parsing/parse_json.py +143 -0
  62. langroid/parsing/parser.py +20 -7
  63. langroid/parsing/repo_loader.py +108 -46
  64. langroid/parsing/search.py +8 -0
  65. langroid/parsing/table_loader.py +44 -0
  66. langroid/parsing/url_loader.py +59 -13
  67. langroid/parsing/urls.py +18 -9
  68. langroid/parsing/utils.py +130 -9
  69. langroid/parsing/web_search.py +73 -0
  70. langroid/prompts/__init__.py +7 -0
  71. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  72. langroid/prompts/prompts_config.py +1 -1
  73. langroid/utils/__init__.py +10 -0
  74. langroid/utils/algorithms/__init__.py +3 -0
  75. langroid/utils/configuration.py +0 -1
  76. langroid/utils/constants.py +4 -0
  77. langroid/utils/logging.py +2 -5
  78. langroid/utils/output/__init__.py +15 -2
  79. langroid/utils/output/status.py +33 -0
  80. langroid/utils/pandas_utils.py +30 -0
  81. langroid/utils/pydantic_utils.py +446 -4
  82. langroid/utils/system.py +36 -1
  83. langroid/vector_store/__init__.py +34 -2
  84. langroid/vector_store/base.py +33 -2
  85. langroid/vector_store/chromadb.py +42 -13
  86. langroid/vector_store/lancedb.py +226 -60
  87. langroid/vector_store/meilisearch.py +7 -6
  88. langroid/vector_store/momento.py +3 -2
  89. langroid/vector_store/qdrantdb.py +82 -11
  90. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
  91. langroid-0.1.219.dist-info/RECORD +127 -0
  92. langroid/agent/special/recipient_validator_agent.py +0 -157
  93. langroid/parsing/json.py +0 -64
  94. langroid/utils/web/selenium_login.py +0 -36
  95. langroid-0.1.139.dist-info/RECORD +0 -103
  96. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
  97. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
langroid/embedding_models/protoc/embeddings_pb2.py
@@ -0,0 +1,33 @@
+ # -*- coding: utf-8 -*-
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
+ # source: embeddings.proto
+ # Protobuf Python Version: 4.25.1
+ """Generated protocol buffer code."""
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import descriptor_pool as _descriptor_pool
+ from google.protobuf import symbol_database as _symbol_database
+ from google.protobuf.internal import builder as _builder
+
+ # @@protoc_insertion_point(imports)
+
+ _sym_db = _symbol_database.Default()
+
+
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+     b'\n\x10\x65mbeddings.proto"K\n\x10\x45mbeddingRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x12\n\nbatch_size\x18\x02 \x01(\x05\x12\x0f\n\x07strings\x18\x03 \x03(\t"%\n\x0b\x42\x61tchEmbeds\x12\x16\n\x06\x65mbeds\x18\x01 \x03(\x0b\x32\x06.Embed"\x16\n\x05\x45mbed\x12\r\n\x05\x65mbed\x18\x01 \x03(\x02\x32\x37\n\tEmbedding\x12*\n\x05\x45mbed\x12\x11.EmbeddingRequest\x1a\x0c.BatchEmbeds"\x00\x62\x06proto3'
+ )
+
+ _globals = globals()
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "embeddings_pb2", _globals)
+ if _descriptor._USE_C_DESCRIPTORS == False:
+     DESCRIPTOR._options = None
+     _globals["_EMBEDDINGREQUEST"]._serialized_start = 20
+     _globals["_EMBEDDINGREQUEST"]._serialized_end = 95
+     _globals["_BATCHEMBEDS"]._serialized_start = 97
+     _globals["_BATCHEMBEDS"]._serialized_end = 134
+     _globals["_EMBED"]._serialized_start = 136
+     _globals["_EMBED"]._serialized_end = 158
+     _globals["_EMBEDDING"]._serialized_start = 160
+     _globals["_EMBEDDING"]._serialized_end = 215
+ # @@protoc_insertion_point(module_scope)
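For orientation: the serialized descriptor above encodes the three messages (EmbeddingRequest, BatchEmbeds, Embed) and the single Embedding service declared in embeddings.proto. A minimal sketch of round-tripping a request through the generated classes (field names are confirmed by the .pyi stub below; the values are illustrative):

    import langroid.embedding_models.protoc.embeddings_pb2 as pb

    # Build a request, serialize to wire format, and parse it back.
    req = pb.EmbeddingRequest(
        model_name="BAAI/bge-large-en-v1.5",  # illustrative value
        batch_size=32,
        strings=["hello", "world"],
    )
    data = req.SerializeToString()
    parsed = pb.EmbeddingRequest.FromString(data)
    assert list(parsed.strings) == ["hello", "world"]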
langroid/embedding_models/protoc/embeddings_pb2.pyi
@@ -0,0 +1,50 @@
+ from typing import (
+     ClassVar as _ClassVar,
+ )
+ from typing import (
+     Iterable as _Iterable,
+ )
+ from typing import (
+     Mapping as _Mapping,
+ )
+ from typing import (
+     Optional as _Optional,
+ )
+ from typing import (
+     Union as _Union,
+ )
+
+ from google.protobuf import descriptor as _descriptor
+ from google.protobuf import message as _message
+ from google.protobuf.internal import containers as _containers
+
+ DESCRIPTOR: _descriptor.FileDescriptor
+
+ class EmbeddingRequest(_message.Message):
+     __slots__ = ("model_name", "batch_size", "strings")
+     MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+     BATCH_SIZE_FIELD_NUMBER: _ClassVar[int]
+     STRINGS_FIELD_NUMBER: _ClassVar[int]
+     model_name: str
+     batch_size: int
+     strings: _containers.RepeatedScalarFieldContainer[str]
+     def __init__(
+         self,
+         model_name: _Optional[str] = ...,
+         batch_size: _Optional[int] = ...,
+         strings: _Optional[_Iterable[str]] = ...,
+     ) -> None: ...
+
+ class BatchEmbeds(_message.Message):
+     __slots__ = ("embeds",)
+     EMBEDS_FIELD_NUMBER: _ClassVar[int]
+     embeds: _containers.RepeatedCompositeFieldContainer[Embed]
+     def __init__(
+         self, embeds: _Optional[_Iterable[_Union[Embed, _Mapping]]] = ...
+     ) -> None: ...
+
+ class Embed(_message.Message):
+     __slots__ = ("embed",)
+     EMBED_FIELD_NUMBER: _ClassVar[int]
+     embed: _containers.RepeatedScalarFieldContainer[float]
+     def __init__(self, embed: _Optional[_Iterable[float]] = ...) -> None: ...
langroid/embedding_models/protoc/embeddings_pb2_grpc.py
@@ -0,0 +1,79 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+
+ import langroid.embedding_models.protoc.embeddings_pb2 as embeddings__pb2
+
+
+ class EmbeddingStub(object):
+     """Missing associated documentation comment in .proto file."""
+
+     def __init__(self, channel):
+         """Constructor.
+
+         Args:
+             channel: A grpc.Channel.
+         """
+         self.Embed = channel.unary_unary(
+             "/Embedding/Embed",
+             request_serializer=embeddings__pb2.EmbeddingRequest.SerializeToString,
+             response_deserializer=embeddings__pb2.BatchEmbeds.FromString,
+         )
+
+
+ class EmbeddingServicer(object):
+     """Missing associated documentation comment in .proto file."""
+
+     def Embed(self, request, context):
+         """Missing associated documentation comment in .proto file."""
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details("Method not implemented!")
+         raise NotImplementedError("Method not implemented!")
+
+
+ def add_EmbeddingServicer_to_server(servicer, server):
+     rpc_method_handlers = {
+         "Embed": grpc.unary_unary_rpc_method_handler(
+             servicer.Embed,
+             request_deserializer=embeddings__pb2.EmbeddingRequest.FromString,
+             response_serializer=embeddings__pb2.BatchEmbeds.SerializeToString,
+         ),
+     }
+     generic_handler = grpc.method_handlers_generic_handler(
+         "Embedding", rpc_method_handlers
+     )
+     server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+ class Embedding(object):
+     """Missing associated documentation comment in .proto file."""
+
+     @staticmethod
+     def Embed(
+         request,
+         target,
+         options=(),
+         channel_credentials=None,
+         call_credentials=None,
+         insecure=False,
+         compression=None,
+         wait_for_ready=None,
+         timeout=None,
+         metadata=None,
+     ):
+         return grpc.experimental.unary_unary(
+             request,
+             target,
+             "/Embedding/Embed",
+             embeddings__pb2.EmbeddingRequest.SerializeToString,
+             embeddings__pb2.BatchEmbeds.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+         )
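Client-side, the generated stub is used over an insecure channel. A sketch mirroring what RemoteEmbeddings.embedding_fn (next file) does, assuming a server is already listening on localhost:50052 (the default port below):

    import grpc

    import langroid.embedding_models.protoc.embeddings_pb2 as pb
    import langroid.embedding_models.protoc.embeddings_pb2_grpc as pb_grpc

    # Assumes an embedding server is already listening on this address.
    with grpc.insecure_channel("localhost:50052") as channel:
        stub = pb_grpc.EmbeddingStub(channel)
        response = stub.Embed(pb.EmbeddingRequest(strings=["some text"]))
        vectors = [list(e.embed) for e in response.embeds]  # one float list per input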
langroid/embedding_models/remote_embeds.py
@@ -0,0 +1,153 @@
+ """
+ If run as a script, starts an RPC server which handles remote
+ embedding requests:
+
+ For example:
+ python3 -m langroid.embedding_models.remote_embeds --port `port`
+
+ where `port` is the port at which the service is exposed. Currently,
+ supports insecure connections only, and this should NOT be exposed to
+ the internet.
+ """
+
+ import atexit
+ import subprocess
+ import time
+ from typing import Callable, Optional
+
+ import grpc
+ from fire import Fire
+
+ import langroid.embedding_models.models as em
+ import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb
+ import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_grpc
+ from langroid.mytypes import Embeddings
+
+
+ class RemoteEmbeddingRPCs(embeddings_grpc.EmbeddingServicer):
+     def __init__(
+         self,
+         model_name: str,
+         batch_size: int,
+         data_parallel: bool,
+         device: Optional[str],
+         devices: Optional[list[str]],
+     ):
+         super().__init__()
+         self.embedding_fn = em.SentenceTransformerEmbeddings(
+             em.SentenceTransformerEmbeddingsConfig(
+                 model_name=model_name,
+                 batch_size=batch_size,
+                 data_parallel=data_parallel,
+                 device=device,
+                 devices=devices,
+             )
+         ).embedding_fn()
+
+     def Embed(
+         self, request: embeddings_pb.EmbeddingRequest, _: grpc.RpcContext
+     ) -> embeddings_pb.BatchEmbeds:
+         embeds = self.embedding_fn(list(request.strings))
+
+         embeds_pb = [embeddings_pb.Embed(embed=e) for e in embeds]
+
+         return embeddings_pb.BatchEmbeds(embeds=embeds_pb)
+
+
+ class RemoteEmbeddingsConfig(em.SentenceTransformerEmbeddingsConfig):
+     api_base: str = "localhost"
+     port: int = 50052
+     # The below are used only when waiting for server creation
+     poll_delay: float = 0.01
+     max_retries: int = 1000
+
+
+ class RemoteEmbeddings(em.SentenceTransformerEmbeddings):
+     def __init__(self, config: RemoteEmbeddingsConfig = RemoteEmbeddingsConfig()):
+         super().__init__(config)
+         self.config: RemoteEmbeddingsConfig = config
+         self.have_started_server: bool = False
+
+     def embedding_fn(self) -> Callable[[list[str]], Embeddings]:
+         def fn(texts: list[str]) -> Embeddings:
+             url = f"{self.config.api_base}:{self.config.port}"
+             with grpc.insecure_channel(url) as channel:
+                 stub = embeddings_grpc.EmbeddingStub(channel)  # type: ignore
+                 response = stub.Embed(
+                     embeddings_pb.EmbeddingRequest(
+                         strings=texts,
+                     )
+                 )
+
+                 return [list(emb.embed) for emb in response.embeds]
+
+         def with_handling(texts: list[str]) -> Embeddings:
+             # In local mode, start the server if it has not already
+             # been started
+             if self.config.api_base == "localhost" and not self.have_started_server:
+                 try:
+                     return fn(texts)
+                 # Occurs when the server hasn't been started
+                 except grpc.RpcError:
+                     self.have_started_server = True
+                     # Start the server
+                     proc = subprocess.Popen(
+                         [
+                             "python3",
+                             __file__,
+                             "--bind_address_base",
+                             self.config.api_base,
+                             "--port",
+                             str(self.config.port),
+                             "--batch_size",
+                             str(self.config.batch_size),
+                             "--model_name",
+                             self.config.model_name,
+                         ],
+                     )
+
+                     atexit.register(lambda: proc.terminate())
+
+                     for _ in range(self.config.max_retries - 1):
+                         try:
+                             return fn(texts)
+                         except grpc.RpcError:
+                             time.sleep(self.config.poll_delay)
+
+             # The remote is not local or we have exhausted retries
+             # We should now raise an error if the server is not accessible
+             return fn(texts)
+
+         return with_handling
+
+
+ async def serve(
+     bind_address_base: str = "localhost",
+     port: int = 50052,
+     batch_size: int = 512,
+     data_parallel: bool = False,
+     device: Optional[str] = None,
+     devices: Optional[list[str]] = None,
+     model_name: str = "BAAI/bge-large-en-v1.5",
+ ) -> None:
+     """Starts the RPC server."""
+     server = grpc.aio.server()
+     embeddings_grpc.add_EmbeddingServicer_to_server(
+         RemoteEmbeddingRPCs(
+             model_name=model_name,
+             batch_size=batch_size,
+             data_parallel=data_parallel,
+             device=device,
+             devices=devices,
+         ),
+         server,
+     )  # type: ignore
+     url = f"{bind_address_base}:{port}"
+     server.add_insecure_port(url)
+     await server.start()
+     print(f"Embedding server started, listening on {url}")
+     await server.wait_for_termination()
+
+
+ if __name__ == "__main__":
+     Fire(serve)
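Usage sketch: the server can be started manually as shown in the module docstring, or, in local mode, RemoteEmbeddings spawns it on the first failed call. The values shown are the RemoteEmbeddingsConfig defaults; this assumes the sentence-transformers extras are installed:

    # Optional manual start:
    #   python3 -m langroid.embedding_models.remote_embeds --port 50052
    from langroid.embedding_models.remote_embeds import (
        RemoteEmbeddings,
        RemoteEmbeddingsConfig,
    )

    config = RemoteEmbeddingsConfig(api_base="localhost", port=50052)  # defaults
    embed_fn = RemoteEmbeddings(config).embedding_fn()
    vectors = embed_fn(["hello", "world"])  # auto-starts a local server if needed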
langroid/language_models/__init__.py
@@ -21,3 +21,25 @@ from .openai_gpt import (
      OpenAIGPT,
  )
  from .azure_openai import AzureConfig, AzureGPT
+
+ __all__ = [
+     "utils",
+     "config",
+     "base",
+     "openai_gpt",
+     "azure_openai",
+     "prompt_formatter",
+     "LLMConfig",
+     "LLMMessage",
+     "LLMFunctionCall",
+     "LLMFunctionSpec",
+     "Role",
+     "LLMTokenUsage",
+     "LLMResponse",
+     "OpenAIChatModel",
+     "OpenAICompletionModel",
+     "OpenAIGPTConfig",
+     "OpenAIGPT",
+     "AzureConfig",
+     "AzureGPT",
+ ]
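These exports flatten the subpackage namespace, so the common classes can be imported directly. A sketch (assumes OPENAI_API_KEY is set in the environment):

    from langroid.language_models import OpenAIGPT, OpenAIGPTConfig

    llm = OpenAIGPT(OpenAIGPTConfig())  # default config; API key read from env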
langroid/language_models/azure_openai.py
@@ -22,6 +22,9 @@ class AzureConfig(OpenAIGPTConfig):
          chose for your deployment when you deployed a model.
      model_name (str): can be set in the ``.env`` file as ``AZURE_GPT_MODEL_NAME``
          and should be based on the model name chosen during setup.
+     model_version (str): can be set in the ``.env`` file as
+         ``AZURE_OPENAI_MODEL_VERSION`` and should be based on the model name
+         chosen during setup.
      """

      api_key: str = ""  # CAUTION: set this ONLY via env var AZURE_OPENAI_API_KEY
@@ -29,6 +32,7 @@ class AzureConfig(OpenAIGPTConfig):
      api_version: str = "2023-05-15"
      deployment_name: str = ""
      model_name: str = ""
+     model_version: str = ""  # is used to determine the cost of using the model
      api_base: str = ""

      # all of the vars above can be set via env vars,
@@ -50,6 +54,7 @@ class AzureGPT(OpenAIGPT):
          api_base (str): Azure API base url
          api_version (str): Azure API version
          model_name (str): the name of gpt model in your deployment
+         model_version (str): the version of gpt model in your deployment
      """

      def __init__(self, config: AzureConfig):
@@ -89,10 +94,7 @@ class AzureGPT(OpenAIGPT):
          # set the chat model to be the same as the model_name
          # This corresponds to the gpt model you chose for your deployment
          # when you deployed a model
-         if "35-turbo" in self.config.model_name:
-             self.config.chat_model = OpenAIChatModel.GPT3_5_TURBO
-         else:
-             self.config.chat_model = OpenAIChatModel.GPT4
+         self.set_chat_model()

          self.client = AzureOpenAI(
              api_key=self.config.api_key,
@@ -107,3 +109,44 @@ class AzureGPT(OpenAIGPT):
              azure_deployment=self.config.deployment_name,
              timeout=Timeout(self.config.timeout),
          )
+
+     def set_chat_model(self) -> None:
+         """
+         Sets the chat model configuration based on the model name specified in the
+         ``.env``. This function checks the `model_name` in the configuration and sets
+         the appropriate chat model in the `config.chat_model`. It supports handling for
+         '35-turbo' and 'gpt-4' models. For 'gpt-4', it further delegates the handling
+         to `handle_gpt4_model` method. If the model name does not match any predefined
+         models, it defaults to `OpenAIChatModel.GPT4`.
+         """
+         MODEL_35_TURBO = "35-turbo"
+         MODEL_GPT4 = "gpt-4"
+
+         if self.config.model_name == MODEL_35_TURBO:
+             self.config.chat_model = OpenAIChatModel.GPT3_5_TURBO
+         elif self.config.model_name == MODEL_GPT4:
+             self.handle_gpt4_model()
+         else:
+             self.config.chat_model = OpenAIChatModel.GPT4
+
+     def handle_gpt4_model(self) -> None:
+         """
+         Handles the setting of the GPT-4 model in the configuration.
+         This function checks the `model_version` in the configuration.
+         If the version is not set, it raises a ValueError indicating that the model
+         version needs to be specified in the ``.env`` file.
+         It sets `OpenAIChatModel.GPT4_TURBO` if the version is
+         '1106-Preview', otherwise, it defaults to setting `OpenAIChatModel.GPT4`.
+         """
+         VERSION_1106_PREVIEW = "1106-Preview"
+
+         if self.config.model_version == "":
+             raise ValueError(
+                 "AZURE_OPENAI_MODEL_VERSION not set in .env file. "
+                 "Please set it to the chat model version used in your deployment."
+             )
+
+         if self.config.model_version == VERSION_1106_PREVIEW:
+             self.config.chat_model = OpenAIChatModel.GPT4_TURBO
+         else:
+             self.config.chat_model = OpenAIChatModel.GPT4
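Note the behavior change: the old code matched "35-turbo" as a substring of model_name, while set_chat_model compares for equality, and GPT-4 deployments must now declare their version. A sketch of the resulting resolution (assumes the usual Azure env vars such as AZURE_OPENAI_API_KEY are set):

    from langroid.language_models.azure_openai import AzureConfig, AzureGPT

    # model_name "35-turbo"                            -> GPT3_5_TURBO
    # model_name "gpt-4", model_version "1106-Preview" -> GPT4_TURBO
    # model_name "gpt-4", model_version ""             -> ValueError
    # any other model_name                             -> GPT4
    config = AzureConfig(model_name="gpt-4", model_version="1106-Preview")
    llm = AzureGPT(config)  # config.chat_model is now OpenAIChatModel.GPT4_TURBO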
langroid/language_models/base.py
@@ -3,18 +3,18 @@ import asyncio
  import json
  import logging
  from abc import ABC, abstractmethod
+ from datetime import datetime
  from enum import Enum
- from typing import Any, Dict, List, Optional, Tuple, Type, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

  import aiohttp
- from pydantic import BaseModel, BaseSettings
+ from pydantic import BaseModel, BaseSettings, Field

  from langroid.cachedb.momento_cachedb import MomentoCacheConfig
  from langroid.cachedb.redis_cachedb import RedisCacheConfig
- from langroid.language_models.config import Llama2FormatterConfig, PromptFormatterConfig
  from langroid.mytypes import Document
  from langroid.parsing.agent_chats import parse_message
- from langroid.parsing.json import top_level_json_field
+ from langroid.parsing.parse_json import top_level_json_field
  from langroid.prompts.dialog import collate_chat_history
  from langroid.prompts.templates import (
      EXTRACTION_PROMPT_GPT4,
@@ -26,16 +26,21 @@ from langroid.utils.output.printing import show_if_debug
  logger = logging.getLogger(__name__)


+ def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
+     pass
+
+
  class LLMConfig(BaseSettings):
      type: str = "openai"
+     streamer: Optional[Callable[[Any], None]] = noop_fn
      api_base: str | None = None
-     formatter: None | PromptFormatterConfig = Llama2FormatterConfig()
+     formatter: None | str = None
      timeout: int = 20  # timeout for API requests
      chat_model: str = ""
      completion_model: str = ""
      temperature: float = 0.0
-     chat_context_length: int = 1024
-     completion_context_length: int = 1024
+     chat_context_length: int = 8000
+     completion_context_length: int = 8000
      max_output_tokens: int = 1024  # generate at most this many tokens
      # if input length + max_output_tokens > context length of model,
      # we will try shortening requested output
@@ -72,8 +77,11 @@ class LLMFunctionCall(BaseModel):
          fun_args_str = message["arguments"]
          # sometimes may be malformed with invalid indents,
          # so we try to be safe by removing newlines.
-         fun_args_str = fun_args_str.replace("\n", "").strip()
-         fun_args = ast.literal_eval(fun_args_str)
+         if fun_args_str is not None:
+             fun_args_str = fun_args_str.replace("\n", "").strip()
+             fun_args = ast.literal_eval(fun_args_str)
+         else:
+             fun_args = None
          fun_call.arguments = fun_args

          return fun_call
@@ -135,6 +143,7 @@ class LLMMessage(BaseModel):
      tool_id: str = ""  # used by OpenAIAssistant
      content: str
      function_call: Optional[LLMFunctionCall] = None
+     timestamp: datetime = Field(default_factory=datetime.utcnow)

      def api_dict(self) -> Dict[str, Any]:
          """
@@ -157,6 +166,7 @@ class LLMMessage(BaseModel):
                  dict_no_none["function_call"]["arguments"]
              )
          dict_no_none.pop("tool_id", None)
+         dict_no_none.pop("timestamp", None)
          return dict_no_none

      def __str__(self) -> str:
@@ -179,6 +189,12 @@ class LLMResponse(BaseModel):
      usage: Optional[LLMTokenUsage]
      cached: bool = False

+     def __str__(self) -> str:
+         if self.function_call is not None:
+             return str(self.function_call)
+         else:
+             return self.message
+
      def to_LLMMessage(self) -> LLMMessage:
          content = self.message
          role = Role.ASSISTANT if self.function_call is None else Role.FUNCTION
@@ -245,7 +261,7 @@ class LanguageModel(ABC):
      # usage cost by model, accumulates here
      usage_cost_dict: Dict[str, LLMTokenUsage] = {}

-     def __init__(self, config: LLMConfig):
+     def __init__(self, config: LLMConfig = LLMConfig()):
          self.config = config

      @staticmethod
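The new timestamp field is set automatically but stripped (along with tool_id) from the payload sent to the API. A sketch (the role argument is part of LLMMessage but not shown in this hunk):

    from langroid.language_models.base import LLMMessage, Role

    msg = LLMMessage(role=Role.USER, content="hello")
    print(msg.timestamp)  # populated via default_factory=datetime.utcnow
    payload = msg.api_dict()
    assert "timestamp" not in payload and "tool_id" not in payload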
langroid/language_models/config.py
@@ -11,3 +11,8 @@ class PromptFormatterConfig(BaseSettings):

  class Llama2FormatterConfig(PromptFormatterConfig):
      use_bos_eos: bool = False
+
+
+ class HFPromptFormatterConfig(PromptFormatterConfig):
+     type: str = "hf"
+     model_name: str
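Since model_name has no default, it must be supplied when constructing the config. A minimal sketch with an illustrative HuggingFace model id:

    from langroid.language_models.config import HFPromptFormatterConfig

    # Model id is illustrative; any HF chat model with a known prompt format works.
    fmt = HFPromptFormatterConfig(model_name="mistralai/Mistral-7B-Instruct-v0.1")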