jl-ecms-client 0.2.8 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jl-ecms-client might be problematic.
- jl_ecms_client-0.2.8.dist-info/METADATA +295 -0
- jl_ecms_client-0.2.8.dist-info/RECORD +53 -0
- jl_ecms_client-0.2.8.dist-info/WHEEL +5 -0
- jl_ecms_client-0.2.8.dist-info/licenses/LICENSE +190 -0
- jl_ecms_client-0.2.8.dist-info/top_level.txt +1 -0
- mirix/client/__init__.py +14 -0
- mirix/client/client.py +405 -0
- mirix/client/constants.py +60 -0
- mirix/client/remote_client.py +1136 -0
- mirix/client/utils.py +34 -0
- mirix/helpers/__init__.py +1 -0
- mirix/helpers/converters.py +429 -0
- mirix/helpers/datetime_helpers.py +90 -0
- mirix/helpers/json_helpers.py +47 -0
- mirix/helpers/message_helpers.py +74 -0
- mirix/helpers/tool_rule_solver.py +166 -0
- mirix/schemas/__init__.py +1 -0
- mirix/schemas/agent.py +401 -0
- mirix/schemas/block.py +188 -0
- mirix/schemas/cloud_file_mapping.py +29 -0
- mirix/schemas/embedding_config.py +114 -0
- mirix/schemas/enums.py +69 -0
- mirix/schemas/environment_variables.py +82 -0
- mirix/schemas/episodic_memory.py +170 -0
- mirix/schemas/file.py +57 -0
- mirix/schemas/health.py +10 -0
- mirix/schemas/knowledge_vault.py +181 -0
- mirix/schemas/llm_config.py +187 -0
- mirix/schemas/memory.py +318 -0
- mirix/schemas/message.py +1315 -0
- mirix/schemas/mirix_base.py +107 -0
- mirix/schemas/mirix_message.py +411 -0
- mirix/schemas/mirix_message_content.py +230 -0
- mirix/schemas/mirix_request.py +39 -0
- mirix/schemas/mirix_response.py +183 -0
- mirix/schemas/openai/__init__.py +1 -0
- mirix/schemas/openai/chat_completion_request.py +122 -0
- mirix/schemas/openai/chat_completion_response.py +144 -0
- mirix/schemas/openai/chat_completions.py +127 -0
- mirix/schemas/openai/embedding_response.py +11 -0
- mirix/schemas/openai/openai.py +229 -0
- mirix/schemas/organization.py +38 -0
- mirix/schemas/procedural_memory.py +151 -0
- mirix/schemas/providers.py +816 -0
- mirix/schemas/resource_memory.py +134 -0
- mirix/schemas/sandbox_config.py +132 -0
- mirix/schemas/semantic_memory.py +162 -0
- mirix/schemas/source.py +96 -0
- mirix/schemas/step.py +53 -0
- mirix/schemas/tool.py +241 -0
- mirix/schemas/tool_rule.py +209 -0
- mirix/schemas/usage.py +31 -0
- mirix/schemas/user.py +67 -0
mirix/schemas/providers.py

@@ -0,0 +1,816 @@
+from datetime import datetime
+from typing import List, Optional
+
+from pydantic import Field, model_validator
+
+from mirix.client.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
+from mirix.llm_api.azure_openai import (
+    get_azure_chat_completions_endpoint,
+    get_azure_embeddings_endpoint,
+)
+from mirix.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
+from mirix.log import get_logger
+from mirix.schemas.embedding_config import EmbeddingConfig
+from mirix.schemas.llm_config import LLMConfig
+from mirix.schemas.mirix_base import MirixBase
+
+logger = get_logger(__name__)
+
+
+class ProviderBase(MirixBase):
+    __id_prefix__ = "provider"
+
+class Provider(ProviderBase):
+    id: Optional[str] = Field(
+        None,
+        description="The id of the provider, lazily created by the database manager.",
+    )
+    name: str = Field(..., description="The name of the provider")
+    api_key: Optional[str] = Field(
+        None, description="API key used for requests to the provider."
+    )
+    organization_id: Optional[str] = Field(
+        None, description="The organization id of the user"
+    )
+    updated_at: Optional[datetime] = Field(
+        None, description="The last update timestamp of the provider."
+    )
+
+    def resolve_identifier(self):
+        if not self.id:
+            self.id = ProviderBase._generate_id(prefix=ProviderBase.__id_prefix__)
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        return []
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        return []
+
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        raise NotImplementedError
+
+    def provider_tag(self) -> str:
+        """String representation of the provider for display purposes"""
+        raise NotImplementedError
+
+    def get_handle(self, model_name: str) -> str:
+        return f"{self.name}/{model_name}"
+
+class ProviderCreate(ProviderBase):
+    name: str = Field(..., description="The name of the provider.")
+    api_key: str = Field(..., description="API key used for requests to the provider.")
+
+class ProviderUpdate(ProviderBase):
+    id: str = Field(..., description="The id of the provider to update.")
+    api_key: str = Field(..., description="API key used for requests to the provider.")
+
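For illustration only (not part of the packaged diff): the Provider base class above defines the interface every concrete provider below fills in, and get_handle builds a "<provider name>/<model name>" identifier. A minimal sketch with hypothetical values, assuming the module imports as mirix.schemas.providers per the file listing:

    from mirix.schemas.providers import Provider

    provider = Provider(name="example")            # hypothetical instance
    print(provider.get_handle("example-model"))    # -> "example/example-model"
    print(provider.list_llm_models())              # -> [] (base-class default)
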
+class MirixProvider(Provider):
+    name: str = "mirix"
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        return [
+            LLMConfig(
+                model="mirix-free", # NOTE: renamed
+                model_endpoint_type="openai",
+                model_endpoint="https://inference.memgpt.ai",
+                context_window=8192,
+                handle=self.get_handle("mirix-free"),
+            )
+        ]
+
+    def list_embedding_models(self):
+        return [
+            EmbeddingConfig(
+                embedding_model="mirix-free", # NOTE: renamed
+                embedding_endpoint_type="hugging-face",
+                embedding_endpoint="https://embeddings.memgpt.ai",
+                embedding_dim=1024,
+                embedding_chunk_size=300,
+                handle=self.get_handle("mirix-free"),
+            )
+        ]
+
+class OpenAIProvider(Provider):
+    name: str = "openai"
+    api_key: str = Field(..., description="API key for the OpenAI API.")
+    base_url: str = Field(..., description="Base URL for the OpenAI API.")
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from mirix.llm_api.openai import openai_get_model_list
+
+        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
+        # See: https://openrouter.ai/docs/requests
+        extra_params = (
+            {"supported_parameters": "tools"}
+            if "openrouter.ai" in self.base_url
+            else None
+        )
+        response = openai_get_model_list(
+            self.base_url, api_key=self.api_key, extra_params=extra_params
+        )
+
+        # TogetherAI's response is missing the 'data' field
+        # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+        if "data" in response:
+            data = response["data"]
+        else:
+            data = response
+
+        configs = []
+        for model in data:
+            assert "id" in model, f"OpenAI model missing 'id' field: {model}"
+            model_name = model["id"]
+
+            if "context_length" in model:
+                # Context length is returned in OpenRouter as "context_length"
+                context_window_size = model["context_length"]
+            else:
+                context_window_size = self.get_model_context_window_size(model_name)
+
+            if not context_window_size:
+                continue
+
+            # TogetherAI includes the type, which we can use to filter out embedding models
+            if self.base_url == "https://api.together.ai/v1":
+                if "type" in model and model["type"] != "chat":
+                    continue
+
+                # for TogetherAI, we need to skip the models that don't support JSON mode / function calling
+                # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
+                # "error": {
+                # "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
+                # "type": "invalid_request_error",
+                # "param": null,
+                # "code": "constraints_model"
+                # }
+                # }
+                if "config" not in model:
+                    continue
+                if "chat_template" not in model["config"]:
+                    continue
+                if model["config"]["chat_template"] is None:
+                    continue
+                if "tools" not in model["config"]["chat_template"]:
+                    continue
+                # if "config" in data and "chat_template" in data["config"] and "tools" not in data["config"]["chat_template"]:
+                # continue
+
+            configs.append(
+                LLMConfig(
+                    model=model_name,
+                    model_endpoint_type="openai",
+                    model_endpoint=self.base_url,
+                    context_window=context_window_size,
+                    handle=self.get_handle(model_name),
+                )
+            )
+
+        # for OpenAI, sort in reverse order
+        if self.base_url == "https://api.openai.com/v1":
+            # alphnumeric sort
+            configs.sort(key=lambda x: x.model, reverse=True)
+
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # TODO: actually automatically list models
+        return [
+            EmbeddingConfig(
+                embedding_model="text-embedding-3-small",
+                embedding_endpoint_type="openai",
+                embedding_endpoint="https://api.openai.com/v1",
+                embedding_dim=1536,
+                embedding_chunk_size=300,
+                handle=self.get_handle("text-embedding-3-small"),
+            ),
+            EmbeddingConfig(
+                embedding_model="text-embedding-3-small",
+                embedding_endpoint_type="openai",
+                embedding_endpoint="https://api.openai.com/v1",
+                embedding_dim=2000,
+                embedding_chunk_size=300,
+                handle=self.get_handle("text-embedding-3-small"),
+            ),
+            EmbeddingConfig(
+                embedding_model="text-embedding-3-large",
+                embedding_endpoint_type="openai",
+                embedding_endpoint="https://api.openai.com/v1",
+                embedding_dim=2000,
+                embedding_chunk_size=300,
+                handle=self.get_handle("text-embedding-3-large"),
+            ),
+        ]
+
+    def get_model_context_window_size(self, model_name: str):
+        if model_name in LLM_MAX_TOKENS:
+            return LLM_MAX_TOKENS[model_name]
+        else:
+            return None
+
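For illustration only (not part of the packaged diff): a minimal sketch of calling OpenAIProvider.list_llm_models as defined above. The key and URL are placeholders, and the call performs a live request through openai_get_model_list.

    from mirix.schemas.providers import OpenAIProvider

    provider = OpenAIProvider(
        api_key="sk-placeholder",
        base_url="https://api.openai.com/v1",
    )
    for config in provider.list_llm_models():
        print(config.handle, config.context_window)
    # Pointing base_url at an openrouter.ai endpoint instead triggers the
    # extra_params={"supported_parameters": "tools"} branch above, so only
    # models with tool-calling support are requested.
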
+class AnthropicProvider(Provider):
+    name: str = "anthropic"
+    api_key: str = Field(..., description="API key for the Anthropic API.")
+    base_url: str = "https://api.anthropic.com/v1"
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from mirix.llm_api.anthropic import anthropic_get_model_list
+
+        models = anthropic_get_model_list(self.base_url, api_key=self.api_key)
+
+        configs = []
+        for model in models:
+            configs.append(
+                LLMConfig(
+                    model=model["name"],
+                    model_endpoint_type="anthropic",
+                    model_endpoint=self.base_url,
+                    context_window=model["context_window"],
+                    handle=self.get_handle(model["name"]),
+                )
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        return []
+
+class MistralProvider(Provider):
+    name: str = "mistral"
+    api_key: str = Field(..., description="API key for the Mistral API.")
+    base_url: str = "https://api.mistral.ai/v1"
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from mirix.llm_api.mistral import mistral_get_model_list
+
+        # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
+        # See: https://openrouter.ai/docs/requests
+        response = mistral_get_model_list(self.base_url, api_key=self.api_key)
+
+        assert "data" in response, (
+            f"Mistral model query response missing 'data' field: {response}"
+        )
+
+        configs = []
+        for model in response["data"]:
+            # If model has chat completions and function calling enabled
+            if (
+                model["capabilities"]["completion_chat"]
+                and model["capabilities"]["function_calling"]
+            ):
+                configs.append(
+                    LLMConfig(
+                        model=model["id"],
+                        model_endpoint_type="openai",
+                        model_endpoint=self.base_url,
+                        context_window=model["max_context_length"],
+                        handle=self.get_handle(model["id"]),
+                    )
+                )
+
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # Not supported for mistral
+        return []
+
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        # Redoing this is fine because it's a pretty lightweight call
+        models = self.list_llm_models()
+
+        for m in models:
+            if model_name in m["id"]:
+                return int(m["max_context_length"])
+
+        return None
+
+class OllamaProvider(OpenAIProvider):
+    """Ollama provider that uses the native /api/generate endpoint
+
+    See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
+    """
+
+    name: str = "ollama"
+    base_url: str = Field(..., description="Base URL for the Ollama API.")
+    api_key: Optional[str] = Field(
+        None, description="API key for the Ollama API (default: `None`)."
+    )
+    default_prompt_formatter: str = Field(
+        ...,
+        description="Default prompt formatter (aka model wrapper) to use on a /completions style API.",
+    )
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+        import requests
+
+        response = requests.get(f"{self.base_url}/api/tags")
+        if response.status_code != 200:
+            raise Exception(f"Failed to list Ollama models: {response.text}")
+        response_json = response.json()
+
+        configs = []
+        for model in response_json["models"]:
+            context_window = self.get_model_context_window(model["name"])
+            if context_window is None:
+                logger.debug("Ollama model %s has no context window", model['name'])
+                continue
+            configs.append(
+                LLMConfig(
+                    model=model["name"],
+                    model_endpoint_type="ollama",
+                    model_endpoint=self.base_url,
+                    model_wrapper=self.default_prompt_formatter,
+                    context_window=context_window,
+                    handle=self.get_handle(model["name"]),
+                )
+            )
+        return configs
+
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        import requests
+
+        response = requests.post(
+            f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}
+        )
+        response_json = response.json()
+
+        ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
+        # possible_keys = [
+        # # OPT
+        # "max_position_embeddings",
+        # # GPT-2
+        # "n_positions",
+        # # MPT
+        # "max_seq_len",
+        # # ChatGLM2
+        # "seq_length",
+        # # Command-R
+        # "model_max_length",
+        # # Others
+        # "max_sequence_length",
+        # "max_seq_length",
+        # "seq_len",
+        # ]
+        # max_position_embeddings
+        # parse model cards: nous, dolphon, llama
+        if "model_info" not in response_json:
+            if "error" in response_json:
+                logger.error(
+                    f"Ollama fetch model info error for {model_name}: {response_json['error']}"
+                )
+            return None
+        for key, value in response_json["model_info"].items():
+            if "context_length" in key:
+                return value
+        return None
+
+    def get_model_embedding_dim(self, model_name: str):
+        import requests
+
+        response = requests.post(
+            f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}
+        )
+        response_json = response.json()
+        if "model_info" not in response_json:
+            if "error" in response_json:
+                logger.error(
+                    f"Ollama fetch model info error for {model_name}: {response_json['error']}"
+                )
+            return None
+        for key, value in response_json["model_info"].items():
+            if "embedding_length" in key:
+                return value
+        return None
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+        import requests
+
+        response = requests.get(f"{self.base_url}/api/tags")
+        if response.status_code != 200:
+            raise Exception(f"Failed to list Ollama models: {response.text}")
+        response_json = response.json()
+
+        configs = []
+        for model in response_json["models"]:
+            embedding_dim = self.get_model_embedding_dim(model["name"])
+            if not embedding_dim:
+                logger.debug("Ollama model %s has no embedding dimension", model['name'])
+                continue
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model["name"],
+                    embedding_endpoint_type="ollama",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=embedding_dim,
+                    embedding_chunk_size=300,
+                    handle=self.get_handle(model["name"]),
+                )
+            )
+        return configs
+
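For illustration only (not part of the packaged diff): OllamaProvider.get_model_context_window above does not hard-code a key name; it returns the first model_info entry whose key contains "context_length". A self-contained sketch of that scan over a hypothetical /api/show payload:

    # Hypothetical "model_info" payload; key names vary by model family.
    model_info = {
        "general.architecture": "llama",
        "llama.context_length": 8192,
        "llama.embedding_length": 4096,
    }
    context_window = next(
        (value for key, value in model_info.items() if "context_length" in key), None
    )
    print(context_window)  # 8192; get_model_embedding_dim scans for "embedding_length" the same way
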
+class GroqProvider(OpenAIProvider):
+    name: str = "groq"
+    base_url: str = "https://api.groq.com/openai/v1"
+    api_key: str = Field(..., description="API key for the Groq API.")
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from mirix.llm_api.openai import openai_get_model_list
+
+        response = openai_get_model_list(self.base_url, api_key=self.api_key)
+        configs = []
+        for model in response["data"]:
+            if "context_window" not in model:
+                continue
+            configs.append(
+                LLMConfig(
+                    model=model["id"],
+                    model_endpoint_type="groq",
+                    model_endpoint=self.base_url,
+                    context_window=model["context_window"],
+                    handle=self.get_handle(model["id"]),
+                )
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        return []
+
+    def get_model_context_window_size(self, model_name: str):
+        raise NotImplementedError
+
+class TogetherProvider(OpenAIProvider):
+    """TogetherAI provider that uses the /completions API
+
+    TogetherAI can also be used via the /chat/completions API
+    by settings OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key
+    and API URL, however /completions is preferred because their /chat/completions
+    function calling support is limited.
+    """
+
+    name: str = "together"
+    base_url: str = "https://api.together.ai/v1"
+    api_key: str = Field(..., description="API key for the TogetherAI API.")
+    default_prompt_formatter: str = Field(
+        ...,
+        description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.",
+    )
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from mirix.llm_api.openai import openai_get_model_list
+
+        response = openai_get_model_list(self.base_url, api_key=self.api_key)
+
+        # TogetherAI's response is missing the 'data' field
+        # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+        if "data" in response:
+            data = response["data"]
+        else:
+            data = response
+
+        configs = []
+        for model in data:
+            assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
+            model_name = model["id"]
+
+            if "context_length" in model:
+                # Context length is returned in OpenRouter as "context_length"
+                context_window_size = model["context_length"]
+            else:
+                context_window_size = self.get_model_context_window_size(model_name)
+
+            # We need the context length for embeddings too
+            if not context_window_size:
+                continue
+
+            # Skip models that are too small for Mirix
+            if context_window_size <= MIN_CONTEXT_WINDOW:
+                continue
+
+            # TogetherAI includes the type, which we can use to filter for embedding models
+            if "type" in model and model["type"] not in ["chat", "language"]:
+                continue
+
+            configs.append(
+                LLMConfig(
+                    model=model_name,
+                    model_endpoint_type="together",
+                    model_endpoint=self.base_url,
+                    model_wrapper=self.default_prompt_formatter,
+                    context_window=context_window_size,
+                    handle=self.get_handle(model_name),
+                )
+            )
+
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # TODO renable once we figure out how to pass API keys through properly
+        return []
+
+        # from mirix.llm_api.openai import openai_get_model_list
+
+        # response = openai_get_model_list(self.base_url, api_key=self.api_key)
+
+        # # TogetherAI's response is missing the 'data' field
+        # # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+        # if "data" in response:
+        # data = response["data"]
+        # else:
+        # data = response
+
+        # configs = []
+        # for model in data:
+        # assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
+        # model_name = model["id"]
+
+        # if "context_length" in model:
+        # # Context length is returned in OpenRouter as "context_length"
+        # context_window_size = model["context_length"]
+        # else:
+        # context_window_size = self.get_model_context_window_size(model_name)
+
+        # if not context_window_size:
+        # continue
+
+        # # TogetherAI includes the type, which we can use to filter out embedding models
+        # if "type" in model and model["type"] not in ["embedding"]:
+        # continue
+
+        # configs.append(
+        # EmbeddingConfig(
+        # embedding_model=model_name,
+        # embedding_endpoint_type="openai",
+        # embedding_endpoint=self.base_url,
+        # embedding_dim=context_window_size,
+        # embedding_chunk_size=300, # TODO: change?
+        # )
+        # )
+
+        # return configs
+
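For illustration only (not part of the packaged diff): the docstring above says TogetherAI can also be driven through the OpenAI-compatible /chat/completions path, which is what the together-specific filtering in OpenAIProvider.list_llm_models anticipates. A sketch of the two routes using the provider classes (the docstring itself mentions the OPENAI_API_KEY / OPENAI_API_BASE environment variables instead); the key and formatter name are placeholders.

    from mirix.schemas.providers import OpenAIProvider, TogetherProvider

    # Preferred route: /completions with an explicit model wrapper
    together = TogetherProvider(
        api_key="together-placeholder",
        default_prompt_formatter="chatml",  # hypothetical wrapper name
    )

    # Alternative route mentioned in the docstring: OpenAI-compatible endpoint
    together_via_openai = OpenAIProvider(
        api_key="together-placeholder",
        base_url="https://api.together.ai/v1",
    )
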
+class GoogleAIProvider(Provider):
+    # gemini
+    name: str = "google_ai"
+    api_key: str = Field(..., description="API key for the Google AI API.")
+    base_url: str = "https://generativelanguage.googleapis.com"
+
+    def list_llm_models(self):
+        from mirix.llm_api.google_ai import google_ai_get_model_list
+
+        model_options = google_ai_get_model_list(
+            base_url=self.base_url, api_key=self.api_key
+        )
+        # filter by 'generateContent' models
+        model_options = [
+            mo
+            for mo in model_options
+            if "generateContent" in mo["supportedGenerationMethods"]
+        ]
+        model_options = [str(m["name"]) for m in model_options]
+
+        # filter by model names
+        model_options = [
+            mo[len("models/") :] if mo.startswith("models/") else mo
+            for mo in model_options
+        ]
+
+        # TODO remove manual filtering for gemini-pro
+        # Add support for all gemini models
+        model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
+
+        configs = []
+        for model in model_options:
+            configs.append(
+                LLMConfig(
+                    model=model,
+                    model_endpoint_type="google_ai",
+                    model_endpoint=self.base_url,
+                    context_window=self.get_model_context_window(model),
+                    handle=self.get_handle(model),
+                )
+            )
+        return configs
+
+    def list_embedding_models(self):
+        from mirix.llm_api.google_ai import google_ai_get_model_list
+
+        # TODO: use base_url instead
+        model_options = google_ai_get_model_list(
+            base_url=self.base_url, api_key=self.api_key
+        )
+        # filter by 'generateContent' models
+        model_options = [
+            mo
+            for mo in model_options
+            if "embedContent" in mo["supportedGenerationMethods"]
+        ]
+        model_options = [str(m["name"]) for m in model_options]
+        model_options = [
+            mo[len("models/") :] if mo.startswith("models/") else mo
+            for mo in model_options
+        ]
+
+        configs = []
+        for model in model_options:
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model,
+                    embedding_endpoint_type="google_ai",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=768,
+                    embedding_chunk_size=300, # NOTE: max is 2048
+                    handle=self.get_handle(model),
+                )
+            )
+        return configs
+
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        from mirix.llm_api.google_ai import google_ai_get_model_context_window
+
+        return google_ai_get_model_context_window(
+            self.base_url, self.api_key, model_name
+        )
+
+class AzureProvider(Provider):
+    name: str = "azure"
+    latest_api_version: str = "2024-09-01-preview" # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
+    base_url: str = Field(
+        ...,
+        description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://mirix.openai.azure.com`.",
+    )
+    api_key: str = Field(..., description="API key for the Azure API.")
+    api_version: str = Field(
+        latest_api_version, description="API version for the Azure API"
+    )
+
+    @model_validator(mode="before")
+    def set_default_api_version(cls, values):
+        """
+        This ensures that api_version is always set to the default if None is passed in.
+        """
+        if values.get("api_version") is None:
+            values["api_version"] = cls.model_fields["latest_api_version"].default
+        return values
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from mirix.llm_api.azure_openai import (
+            azure_openai_get_chat_completion_model_list,
+        )
+
+        model_options = azure_openai_get_chat_completion_model_list(
+            self.base_url, api_key=self.api_key, api_version=self.api_version
+        )
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            context_window_size = self.get_model_context_window(model_name)
+            model_endpoint = get_azure_chat_completions_endpoint(
+                self.base_url, model_name, self.api_version
+            )
+            configs.append(
+                LLMConfig(
+                    model=model_name,
+                    model_endpoint_type="azure",
+                    model_endpoint=model_endpoint,
+                    context_window=context_window_size,
+                    handle=self.get_handle(model_name),
+                ),
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        from mirix.llm_api.azure_openai import azure_openai_get_embeddings_model_list
+
+        model_options = azure_openai_get_embeddings_model_list(
+            self.base_url,
+            api_key=self.api_key,
+            api_version=self.api_version,
+            require_embedding_in_name=True,
+        )
+        configs = []
+        for model_option in model_options:
+            model_name = model_option["id"]
+            model_endpoint = get_azure_embeddings_endpoint(
+                self.base_url, model_name, self.api_version
+            )
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model_name,
+                    embedding_endpoint_type="azure",
+                    embedding_endpoint=model_endpoint,
+                    embedding_dim=768,
+                    embedding_chunk_size=300, # NOTE: max is 2048
+                    handle=self.get_handle(model_name),
+                )
+            )
+        return configs
+
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        """
+        This is hardcoded for now, since there is no API endpoints to retrieve metadata for a model.
+        """
+        return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
+
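For illustration only (not part of the packaged diff): the set_default_api_version validator above runs in "before" mode, so an explicit api_version=None is backfilled with the latest_api_version default before field validation. A sketch with placeholder credentials and URL:

    from mirix.schemas.providers import AzureProvider

    azure = AzureProvider(
        api_key="placeholder-key",
        base_url="https://example.openai.azure.com",
        api_version=None,
    )
    print(azure.api_version)  # "2024-09-01-preview" (the latest_api_version default)
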
+class VLLMChatCompletionsProvider(Provider):
+    """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
+
+    # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
+    name: str = "vllm"
+    base_url: str = Field(..., description="Base URL for the vLLM API.")
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        # not supported with vLLM
+        from mirix.llm_api.openai import openai_get_model_list
+
+        assert self.base_url, "base_url is required for vLLM provider"
+        response = openai_get_model_list(self.base_url, api_key=None)
+
+        configs = []
+        for model in response["data"]:
+            configs.append(
+                LLMConfig(
+                    model=model["id"],
+                    model_endpoint_type="openai",
+                    model_endpoint=self.base_url,
+                    context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"]),
+                )
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # not supported with vLLM
+        return []
+
+class VLLMCompletionsProvider(Provider):
+    """This uses /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
+
+    # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
+    name: str = "vllm"
+    base_url: str = Field(..., description="Base URL for the vLLM API.")
+    default_prompt_formatter: str = Field(
+        ...,
+        description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.",
+    )
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        # not supported with vLLM
+        from mirix.llm_api.openai import openai_get_model_list
+
+        response = openai_get_model_list(self.base_url, api_key=None)
+
+        configs = []
+        for model in response["data"]:
+            configs.append(
+                LLMConfig(
+                    model=model["id"],
+                    model_endpoint_type="vllm",
+                    model_endpoint=self.base_url,
+                    model_wrapper=self.default_prompt_formatter,
+                    context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"]),
+                )
+            )
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # not supported with vLLM
+        return []
+
+class CohereProvider(OpenAIProvider):
+    pass
+
+class AnthropicBedrockProvider(Provider):
+    name: str = "bedrock"
+    aws_region: str = Field(..., description="AWS region for Bedrock")
+
+    def list_llm_models(self):
+        from mirix.llm_api.aws_bedrock import bedrock_get_model_list
+
+        models = bedrock_get_model_list(self.aws_region)
+
+        configs = []
+        for model_summary in models:
+            model_arn = model_summary["inferenceProfileArn"]
+            configs.append(
+                LLMConfig(
+                    model=model_arn,
+                    model_endpoint_type=self.name,
+                    model_endpoint=None,
+                    context_window=self.get_model_context_window(model_arn),
+                    handle=self.get_handle(model_arn),
+                )
+            )
+        return configs
+
+    def list_embedding_models(self):
+        return []
+
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        # Context windows for Claude models
+        from mirix.llm_api.aws_bedrock import bedrock_get_model_context_window
+
+        return bedrock_get_model_context_window(model_name)
+
+    def get_handle(self, model_name: str) -> str:
+        return f"anthropic/{model_name}"
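For illustration only (not part of the packaged diff): the two vLLM providers above differ in which API they target. VLLMChatCompletionsProvider proxies vLLM's OpenAI-compatible /chat/completions endpoint, while VLLMCompletionsProvider targets the raw /completions endpoint and therefore requires a default_prompt_formatter (model wrapper). A sketch with hypothetical values:

    from mirix.schemas.providers import (
        VLLMChatCompletionsProvider,
        VLLMCompletionsProvider,
    )

    chat_provider = VLLMChatCompletionsProvider(base_url="http://localhost:8000/v1")
    raw_provider = VLLMCompletionsProvider(
        base_url="http://localhost:8000/v1",
        default_prompt_formatter="chatml",  # hypothetical wrapper name
    )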