langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -0,0 +1,153 @@
1
+ """
2
+ If run as a script, starts an RPC server which handles remote
3
+ embedding requests:
4
+
5
+ For example:
6
+ python3 -m langroid.embedding_models.remote_embeds --port `port`
7
+
8
+ where `port` is the port at which the service is exposed. Currently,
9
+ supports insecure connections only, and this should NOT be exposed to
10
+ the internet.
11
+ """
12
+
13
+ import atexit
14
+ import subprocess
15
+ import time
16
+ from typing import Callable, Optional
17
+
18
+ import grpc
19
+ from fire import Fire
20
+
21
+ import langroid.embedding_models.models as em
22
+ import langroid.embedding_models.protoc.embeddings_pb2 as embeddings_pb
23
+ import langroid.embedding_models.protoc.embeddings_pb2_grpc as embeddings_grpc
24
+ from langroid.mytypes import Embeddings
25
+
26
+
27
+ class RemoteEmbeddingRPCs(embeddings_grpc.EmbeddingServicer):
28
+ def __init__(
29
+ self,
30
+ model_name: str,
31
+ batch_size: int,
32
+ data_parallel: bool,
33
+ device: Optional[str],
34
+ devices: Optional[list[str]],
35
+ ):
36
+ super().__init__()
37
+ self.embedding_fn = em.SentenceTransformerEmbeddings(
38
+ em.SentenceTransformerEmbeddingsConfig(
39
+ model_name=model_name,
40
+ batch_size=batch_size,
41
+ data_parallel=data_parallel,
42
+ device=device,
43
+ devices=devices,
44
+ )
45
+ ).embedding_fn()
46
+
47
+ def Embed(
48
+ self, request: embeddings_pb.EmbeddingRequest, _: grpc.RpcContext
49
+ ) -> embeddings_pb.BatchEmbeds:
50
+ embeds = self.embedding_fn(list(request.strings))
51
+
52
+ embeds_pb = [embeddings_pb.Embed(embed=e) for e in embeds]
53
+
54
+ return embeddings_pb.BatchEmbeds(embeds=embeds_pb)
55
+
56
+
57
+ class RemoteEmbeddingsConfig(em.SentenceTransformerEmbeddingsConfig):
58
+ api_base: str = "localhost"
59
+ port: int = 50052
60
+ # The below are used only when waiting for server creation
61
+ poll_delay: float = 0.01
62
+ max_retries: int = 1000
63
+
64
+
65
+ class RemoteEmbeddings(em.SentenceTransformerEmbeddings):
66
+ def __init__(self, config: RemoteEmbeddingsConfig = RemoteEmbeddingsConfig()):
67
+ super().__init__(config)
68
+ self.config: RemoteEmbeddingsConfig = config
69
+ self.have_started_server: bool = False
70
+
71
+ def embedding_fn(self) -> Callable[[list[str]], Embeddings]:
72
+ def fn(texts: list[str]) -> Embeddings:
73
+ url = f"{self.config.api_base}:{self.config.port}"
74
+ with grpc.insecure_channel(url) as channel:
75
+ stub = embeddings_grpc.EmbeddingStub(channel) # type: ignore
76
+ response = stub.Embed(
77
+ embeddings_pb.EmbeddingRequest(
78
+ strings=texts,
79
+ )
80
+ )
81
+
82
+ return [list(emb.embed) for emb in response.embeds]
83
+
84
+ def with_handling(texts: list[str]) -> Embeddings:
85
+ # In local mode, start the server if it has not already
86
+ # been started
87
+ if self.config.api_base == "localhost" and not self.have_started_server:
88
+ try:
89
+ return fn(texts)
90
+ # Occurs when the server hasn't been started
91
+ except grpc.RpcError:
92
+ self.have_started_server = True
93
+ # Start the server
94
+ proc = subprocess.Popen(
95
+ [
96
+ "python3",
97
+ __file__,
98
+ "--bind_address_base",
99
+ self.config.api_base,
100
+ "--port",
101
+ str(self.config.port),
102
+ "--batch_size",
103
+ str(self.config.batch_size),
104
+ "--model_name",
105
+ self.config.model_name,
106
+ ],
107
+ )
108
+
109
+ atexit.register(lambda: proc.terminate())
110
+
111
+ for _ in range(self.config.max_retries - 1):
112
+ try:
113
+ return fn(texts)
114
+ except grpc.RpcError:
115
+ time.sleep(self.config.poll_delay)
116
+
117
+ # The remote is not local or we have exhausted retries
118
+ # We should now raise an error if the server is not accessible
119
+ return fn(texts)
120
+
121
+ return with_handling
122
+
123
+
124
+ async def serve(
125
+ bind_address_base: str = "localhost",
126
+ port: int = 50052,
127
+ batch_size: int = 512,
128
+ data_parallel: bool = False,
129
+ device: Optional[str] = None,
130
+ devices: Optional[list[str]] = None,
131
+ model_name: str = "BAAI/bge-large-en-v1.5",
132
+ ) -> None:
133
+ """Starts the RPC server."""
134
+ server = grpc.aio.server()
135
+ embeddings_grpc.add_EmbeddingServicer_to_server(
136
+ RemoteEmbeddingRPCs(
137
+ model_name=model_name,
138
+ batch_size=batch_size,
139
+ data_parallel=data_parallel,
140
+ device=device,
141
+ devices=devices,
142
+ ),
143
+ server,
144
+ ) # type: ignore
145
+ url = f"{bind_address_base}:{port}"
146
+ server.add_insecure_port(url)
147
+ await server.start()
148
+ print(f"Embedding server started, listening on {url}")
149
+ await server.wait_for_termination()
150
+
151
+
152
+ if __name__ == "__main__":
153
+ Fire(serve)
@@ -0,0 +1,45 @@
1
+ from . import utils
2
+ from . import config
3
+ from . import base
4
+ from . import openai_gpt
5
+ from . import azure_openai
6
+ from . import prompt_formatter
7
+
8
+ from .base import (
9
+ LLMConfig,
10
+ LLMMessage,
11
+ LLMFunctionCall,
12
+ LLMFunctionSpec,
13
+ Role,
14
+ LLMTokenUsage,
15
+ LLMResponse,
16
+ )
17
+ from .openai_gpt import (
18
+ OpenAIChatModel,
19
+ OpenAICompletionModel,
20
+ OpenAIGPTConfig,
21
+ OpenAIGPT,
22
+ )
23
+ from .azure_openai import AzureConfig, AzureGPT
24
+
25
+ __all__ = [
26
+ "utils",
27
+ "config",
28
+ "base",
29
+ "openai_gpt",
30
+ "azure_openai",
31
+ "prompt_formatter",
32
+ "LLMConfig",
33
+ "LLMMessage",
34
+ "LLMFunctionCall",
35
+ "LLMFunctionSpec",
36
+ "Role",
37
+ "LLMTokenUsage",
38
+ "LLMResponse",
39
+ "OpenAIChatModel",
40
+ "OpenAICompletionModel",
41
+ "OpenAIGPTConfig",
42
+ "OpenAIGPT",
43
+ "AzureConfig",
44
+ "AzureGPT",
45
+ ]
@@ -1,7 +1,6 @@
1
- import os
2
-
3
- import openai
4
1
  from dotenv import load_dotenv
2
+ from httpx import Timeout
3
+ from openai import AsyncAzureOpenAI, AzureOpenAI
5
4
 
6
5
  from langroid.language_models.openai_gpt import (
7
6
  OpenAIChatModel,
@@ -23,12 +22,26 @@ class AzureConfig(OpenAIGPTConfig):
23
22
  chose for your deployment when you deployed a model.
24
23
  model_name (str): can be set in the ``.env`` file as ``AZURE_GPT_MODEL_NAME``
25
24
  and should be based on the model name chosen during setup.
25
+ model_version (str): can be set in the ``.env`` file as
26
+ ``AZURE_OPENAI_MODEL_VERSION`` and should be based on the model name
27
+ chosen during setup.
26
28
  """
27
29
 
30
+ api_key: str = "" # CAUTION: set this ONLY via env var AZURE_OPENAI_API_KEY
28
31
  type: str = "azure"
29
32
  api_version: str = "2023-05-15"
30
33
  deployment_name: str = ""
31
34
  model_name: str = ""
35
+ model_version: str = "" # is used to determine the cost of using the model
36
+ api_base: str = ""
37
+
38
+ # all of the vars above can be set via env vars,
39
+ # by upper-casing the name and prefixing with `env_prefix`, e.g.
40
+ # AZURE_OPENAI_API_VERSION=2023-05-15
41
+ # This is either done in the .env file, or via an explicit
42
+ # `export AZURE_OPENAI_API_VERSION=...`
43
+ class Config:
44
+ env_prefix = "AZURE_OPENAI_"
32
45
 
33
46
 
34
47
  class AzureGPT(OpenAIGPT):
@@ -41,59 +54,99 @@ class AzureGPT(OpenAIGPT):
41
54
  api_base (str): Azure API base url
42
55
  api_version (str): Azure API version
43
56
  model_name (str): the name of gpt model in your deployment
57
+ model_version (str): the version of gpt model in your deployment
44
58
  """
45
59
 
46
60
  def __init__(self, config: AzureConfig):
61
+ # This will auto-populate config values from .env file
62
+ load_dotenv()
47
63
  super().__init__(config)
48
64
  self.config: AzureConfig = config
49
- self.api_type = config.type
50
- openai.api_type = self.api_type
51
- load_dotenv()
52
- self.api_key = os.getenv("AZURE_API_KEY", "")
53
- if self.api_key == "":
65
+ if self.config.api_key == "":
54
66
  raise ValueError(
55
67
  """
56
- AZURE_API_KEY not set in .env file,
68
+ AZURE_OPENAI_API_KEY not set in .env file,
57
69
  please set it to your Azure API key."""
58
70
  )
59
71
 
60
- self.api_base = os.getenv("AZURE_OPENAI_API_BASE", "")
61
- if self.api_base == "":
72
+ if self.config.api_base == "":
62
73
  raise ValueError(
63
74
  """
64
75
  AZURE_OPENAI_API_BASE not set in .env file,
65
76
  please set it to your Azure API key."""
66
77
  )
67
- # we don't need this for ``api_key`` because it's handled inside
68
- # ``openai_gpt.py`` methods before invoking chat/completion calls
69
- else:
70
- openai.api_base = self.api_base
71
78
 
72
- self.api_version = (
73
- os.getenv("AZURE_OPENAI_API_VERSION", "") or config.api_version
74
- )
75
- openai.api_version = self.api_version
76
-
77
- self.deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "")
78
- if self.deployment_name == "":
79
+ if self.config.deployment_name == "":
79
80
  raise ValueError(
80
81
  """
81
82
  AZURE_OPENAI_DEPLOYMENT_NAME not set in .env file,
82
83
  please set it to your Azure openai deployment name."""
83
84
  )
85
+ self.deployment_name = self.config.deployment_name
84
86
 
85
- self.model_name = os.getenv("AZURE_GPT_MODEL_NAME", "")
86
- if self.model_name == "":
87
+ if self.config.model_name == "":
87
88
  raise ValueError(
88
89
  """
89
- AZURE_GPT_MODEL_NAME not set in .env file,
90
- please set it to chat model name in you deployment model."""
90
+ AZURE_OPENAI_MODEL_NAME not set in .env file,
91
+ please set it to chat model name in your deployment."""
91
92
  )
92
93
 
93
94
  # set the chat model to be the same as the model_name
94
95
  # This corresponds to the gpt model you chose for your deployment
95
96
  # when you deployed a model
96
- if "35-turbo" in self.model_name:
97
+ self.set_chat_model()
98
+
99
+ self.client = AzureOpenAI(
100
+ api_key=self.config.api_key,
101
+ azure_endpoint=self.config.api_base,
102
+ api_version=self.config.api_version,
103
+ azure_deployment=self.config.deployment_name,
104
+ )
105
+ self.async_client = AsyncAzureOpenAI(
106
+ api_key=self.config.api_key,
107
+ azure_endpoint=self.config.api_base,
108
+ api_version=self.config.api_version,
109
+ azure_deployment=self.config.deployment_name,
110
+ timeout=Timeout(self.config.timeout),
111
+ )
112
+
113
+ def set_chat_model(self) -> None:
114
+ """
115
+ Sets the chat model configuration based on the model name specified in the
116
+ ``.env``. This function checks the `model_name` in the configuration and sets
117
+ the appropriate chat model in the `config.chat_model`. It supports handling for
118
+ '35-turbo' and 'gpt-4' models. For 'gpt-4', it further delegates the handling
119
+ to `handle_gpt4_model` method. If the model name does not match any predefined
120
+ models, it defaults to `OpenAIChatModel.GPT4`.
121
+ """
122
+ MODEL_35_TURBO = "35-turbo"
123
+ MODEL_GPT4 = "gpt-4"
124
+
125
+ if self.config.model_name == MODEL_35_TURBO:
97
126
  self.config.chat_model = OpenAIChatModel.GPT3_5_TURBO
127
+ elif self.config.model_name == MODEL_GPT4:
128
+ self.handle_gpt4_model()
129
+ else:
130
+ self.config.chat_model = OpenAIChatModel.GPT4
131
+
132
+ def handle_gpt4_model(self) -> None:
133
+ """
134
+ Handles the setting of the GPT-4 model in the configuration.
135
+ This function checks the `model_version` in the configuration.
136
+ If the version is not set, it raises a ValueError indicating that the model
137
+ version needs to be specified in the ``.env`` file.
138
+ It sets `OpenAIChatModel.GPT4_TURBO` if the version is
139
+ '1106-Preview', otherwise, it defaults to setting `OpenAIChatModel.GPT4`.
140
+ """
141
+ VERSION_1106_PREVIEW = "1106-Preview"
142
+
143
+ if self.config.model_version == "":
144
+ raise ValueError(
145
+ "AZURE_OPENAI_MODEL_VERSION not set in .env file. "
146
+ "Please set it to the chat model version used in your deployment."
147
+ )
148
+
149
+ if self.config.model_version == VERSION_1106_PREVIEW:
150
+ self.config.chat_model = OpenAIChatModel.GPT4_TURBO
98
151
  else:
99
152
  self.config.chat_model = OpenAIChatModel.GPT4
@@ -1,19 +1,20 @@
1
+ import ast
1
2
  import asyncio
2
3
  import json
3
4
  import logging
4
5
  from abc import ABC, abstractmethod
6
+ from datetime import datetime
5
7
  from enum import Enum
6
- from typing import Any, Dict, List, Optional, Tuple, Type, Union
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
7
9
 
8
10
  import aiohttp
9
- from pydantic import BaseModel, BaseSettings
11
+ from pydantic import BaseModel, BaseSettings, Field
10
12
 
11
13
  from langroid.cachedb.momento_cachedb import MomentoCacheConfig
12
14
  from langroid.cachedb.redis_cachedb import RedisCacheConfig
13
- from langroid.language_models.config import Llama2FormatterConfig, PromptFormatterConfig
14
15
  from langroid.mytypes import Document
15
16
  from langroid.parsing.agent_chats import parse_message
16
- from langroid.parsing.json import top_level_json_field
17
+ from langroid.parsing.parse_json import top_level_json_field
17
18
  from langroid.prompts.dialog import collate_chat_history
18
19
  from langroid.prompts.templates import (
19
20
  EXTRACTION_PROMPT_GPT4,
@@ -25,15 +26,21 @@ from langroid.utils.output.printing import show_if_debug
25
26
  logger = logging.getLogger(__name__)
26
27
 
27
28
 
29
+ def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
30
+ pass
31
+
32
+
28
33
  class LLMConfig(BaseSettings):
29
34
  type: str = "openai"
30
- formatter: None | PromptFormatterConfig = Llama2FormatterConfig()
35
+ streamer: Optional[Callable[[Any], None]] = noop_fn
36
+ api_base: str | None = None
37
+ formatter: None | str = None
31
38
  timeout: int = 20 # timeout for API requests
32
39
  chat_model: str = ""
33
40
  completion_model: str = ""
34
41
  temperature: float = 0.0
35
- chat_context_length: int = 1024
36
- completion_context_length: int = 1024
42
+ chat_context_length: int = 8000
43
+ completion_context_length: int = 8000
37
44
  max_output_tokens: int = 1024 # generate at most this many tokens
38
45
  # if input length + max_output_tokens > context length of model,
39
46
  # we will try shortening requested output
@@ -59,6 +66,26 @@ class LLMFunctionCall(BaseModel):
59
66
  to: str = "" # intended recipient
60
67
  arguments: Optional[Dict[str, Any]] = None
61
68
 
69
+ @staticmethod
70
+ def from_dict(message: Dict[str, Any]) -> "LLMFunctionCall":
71
+ """
72
+ Initialize from dictionary.
73
+ Args:
74
+ d: dictionary containing fields to initialize
75
+ """
76
+ fun_call = LLMFunctionCall(name=message["name"])
77
+ fun_args_str = message["arguments"]
78
+ # sometimes may be malformed with invalid indents,
79
+ # so we try to be safe by removing newlines.
80
+ if fun_args_str is not None:
81
+ fun_args_str = fun_args_str.replace("\n", "").strip()
82
+ fun_args = ast.literal_eval(fun_args_str)
83
+ else:
84
+ fun_args = None
85
+ fun_call.arguments = fun_args
86
+
87
+ return fun_call
88
+
62
89
  def __str__(self) -> str:
63
90
  return "FUNC: " + json.dumps(self.dict(), indent=2)
64
91
 
@@ -79,6 +106,20 @@ class LLMTokenUsage(BaseModel):
79
106
  prompt_tokens: int = 0
80
107
  completion_tokens: int = 0
81
108
  cost: float = 0.0
109
+ calls: int = 0 # how many API calls
110
+
111
+ def reset(self) -> None:
112
+ self.prompt_tokens = 0
113
+ self.completion_tokens = 0
114
+ self.cost = 0.0
115
+ self.calls = 0
116
+
117
+ def __str__(self) -> str:
118
+ return (
119
+ f"Tokens = "
120
+ f"(prompt {self.prompt_tokens}, completion {self.completion_tokens}), "
121
+ f"Cost={self.cost}, Calls={self.calls}"
122
+ )
82
123
 
83
124
  @property
84
125
  def total_tokens(self) -> int:
@@ -99,12 +140,16 @@ class LLMMessage(BaseModel):
99
140
 
100
141
  role: Role
101
142
  name: Optional[str] = None
143
+ tool_id: str = "" # used by OpenAIAssistant
102
144
  content: str
103
145
  function_call: Optional[LLMFunctionCall] = None
146
+ timestamp: datetime = Field(default_factory=datetime.utcnow)
104
147
 
105
148
  def api_dict(self) -> Dict[str, Any]:
106
149
  """
107
150
  Convert to dictionary for API request.
151
+ DROP the tool_id, since it is only for use in the Assistant API,
152
+ not the completion API.
108
153
  Returns:
109
154
  dict: dictionary representation of LLM message
110
155
  """
@@ -120,6 +165,8 @@ class LLMMessage(BaseModel):
120
165
  dict_no_none["function_call"]["arguments"] = json.dumps(
121
166
  dict_no_none["function_call"]["arguments"]
122
167
  )
168
+ dict_no_none.pop("tool_id", None)
169
+ dict_no_none.pop("timestamp", None)
123
170
  return dict_no_none
124
171
 
125
172
  def __str__(self) -> str:
@@ -137,10 +184,17 @@ class LLMResponse(BaseModel):
137
184
  """
138
185
 
139
186
  message: str
187
+ tool_id: str = "" # used by OpenAIAssistant
140
188
  function_call: Optional[LLMFunctionCall] = None
141
189
  usage: Optional[LLMTokenUsage]
142
190
  cached: bool = False
143
191
 
192
+ def __str__(self) -> str:
193
+ if self.function_call is not None:
194
+ return str(self.function_call)
195
+ else:
196
+ return self.message
197
+
144
198
  def to_LLMMessage(self) -> LLMMessage:
145
199
  content = self.message
146
200
  role = Role.ASSISTANT if self.function_call is None else Role.FUNCTION
@@ -204,7 +258,10 @@ class LanguageModel(ABC):
204
258
  Abstract base class for language models.
205
259
  """
206
260
 
207
- def __init__(self, config: LLMConfig):
261
+ # usage cost by model, accumulates here
262
+ usage_cost_dict: Dict[str, LLMTokenUsage] = {}
263
+
264
+ def __init__(self, config: LLMConfig = LLMConfig()):
208
265
  self.config = config
209
266
 
210
267
  @staticmethod
@@ -215,6 +272,16 @@ class LanguageModel(ABC):
215
272
  config: configuration for language model
216
273
  Returns: instance of language model
217
274
  """
275
+ if type(config) is LLMConfig:
276
+ raise ValueError(
277
+ """
278
+ Cannot create a Language Model object from LLMConfig.
279
+ Please specify a specific subclass of LLMConfig e.g.,
280
+ OpenAIGPTConfig. If you are creating a ChatAgent from
281
+ a ChatAgentConfig, please specify the `llm` field of this config
282
+ as a specific subclass of LLMConfig, e.g., OpenAIGPTConfig.
283
+ """
284
+ )
218
285
  from langroid.language_models.azure_openai import AzureGPT
219
286
  from langroid.language_models.openai_gpt import OpenAIGPT
220
287
 
@@ -311,18 +378,18 @@ class LanguageModel(ABC):
311
378
  pass
312
379
 
313
380
  @abstractmethod
314
- def generate(self, prompt: str, max_tokens: int) -> LLMResponse:
381
+ def generate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
315
382
  pass
316
383
 
317
384
  @abstractmethod
318
- async def agenerate(self, prompt: str, max_tokens: int) -> LLMResponse:
385
+ async def agenerate(self, prompt: str, max_tokens: int = 200) -> LLMResponse:
319
386
  pass
320
387
 
321
388
  @abstractmethod
322
389
  def chat(
323
390
  self,
324
391
  messages: Union[str, List[LLMMessage]],
325
- max_tokens: int,
392
+ max_tokens: int = 200,
326
393
  functions: Optional[List[LLMFunctionSpec]] = None,
327
394
  function_call: str | Dict[str, str] = "auto",
328
395
  ) -> LLMResponse:
@@ -332,7 +399,7 @@ class LanguageModel(ABC):
332
399
  async def achat(
333
400
  self,
334
401
  messages: Union[str, List[LLMMessage]],
335
- max_tokens: int,
402
+ max_tokens: int = 200,
336
403
  functions: Optional[List[LLMFunctionSpec]] = None,
337
404
  function_call: str | Dict[str, str] = "auto",
338
405
  ) -> LLMResponse:
@@ -350,6 +417,44 @@ class LanguageModel(ABC):
350
417
  def chat_cost(self) -> Tuple[float, float]:
351
418
  return self.config.chat_cost_per_1k_tokens
352
419
 
420
+ def reset_usage_cost(self) -> None:
421
+ for mdl in [self.config.chat_model, self.config.completion_model]:
422
+ if mdl is None:
423
+ return
424
+ if mdl not in self.usage_cost_dict:
425
+ self.usage_cost_dict[mdl] = LLMTokenUsage()
426
+ counter = self.usage_cost_dict[mdl]
427
+ counter.reset()
428
+
429
+ def update_usage_cost(
430
+ self, chat: bool, prompts: int, completions: int, cost: float
431
+ ) -> None:
432
+ """
433
+ Update usage cost for this LLM.
434
+ Args:
435
+ chat (bool): whether to update for chat or completion model
436
+ prompts (int): number of tokens used for prompts
437
+ completions (int): number of tokens used for completions
438
+ cost (float): total token cost in USD
439
+ """
440
+ mdl = self.config.chat_model if chat else self.config.completion_model
441
+ if mdl is None:
442
+ return
443
+ if mdl not in self.usage_cost_dict:
444
+ self.usage_cost_dict[mdl] = LLMTokenUsage()
445
+ counter = self.usage_cost_dict[mdl]
446
+ counter.prompt_tokens += prompts
447
+ counter.completion_tokens += completions
448
+ counter.cost += cost
449
+ counter.calls += 1
450
+
451
+ @classmethod
452
+ def usage_cost_summary(cls) -> str:
453
+ s = ""
454
+ for model, counter in cls.usage_cost_dict.items():
455
+ s += f"{model}: {counter}\n"
456
+ return s
457
+
353
458
  def followup_to_standalone(
354
459
  self, chat_history: List[Tuple[str, str]], question: str
355
460
  ) -> str:
@@ -11,3 +11,8 @@ class PromptFormatterConfig(BaseSettings):
11
11
 
12
12
  class Llama2FormatterConfig(PromptFormatterConfig):
13
13
  use_bos_eos: bool = False
14
+
15
+
16
+ class HFPromptFormatterConfig(PromptFormatterConfig):
17
+ type: str = "hf"
18
+ model_name: str
@@ -0,0 +1,3 @@
1
+ import openai
2
+
3
+ openai.models.list()