jl_ecms_client-0.2.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of jl-ecms-client has been flagged as possibly problematic.

Files changed (53):
  1. jl_ecms_client-0.2.8.dist-info/METADATA +295 -0
  2. jl_ecms_client-0.2.8.dist-info/RECORD +53 -0
  3. jl_ecms_client-0.2.8.dist-info/WHEEL +5 -0
  4. jl_ecms_client-0.2.8.dist-info/licenses/LICENSE +190 -0
  5. jl_ecms_client-0.2.8.dist-info/top_level.txt +1 -0
  6. mirix/client/__init__.py +14 -0
  7. mirix/client/client.py +405 -0
  8. mirix/client/constants.py +60 -0
  9. mirix/client/remote_client.py +1136 -0
  10. mirix/client/utils.py +34 -0
  11. mirix/helpers/__init__.py +1 -0
  12. mirix/helpers/converters.py +429 -0
  13. mirix/helpers/datetime_helpers.py +90 -0
  14. mirix/helpers/json_helpers.py +47 -0
  15. mirix/helpers/message_helpers.py +74 -0
  16. mirix/helpers/tool_rule_solver.py +166 -0
  17. mirix/schemas/__init__.py +1 -0
  18. mirix/schemas/agent.py +401 -0
  19. mirix/schemas/block.py +188 -0
  20. mirix/schemas/cloud_file_mapping.py +29 -0
  21. mirix/schemas/embedding_config.py +114 -0
  22. mirix/schemas/enums.py +69 -0
  23. mirix/schemas/environment_variables.py +82 -0
  24. mirix/schemas/episodic_memory.py +170 -0
  25. mirix/schemas/file.py +57 -0
  26. mirix/schemas/health.py +10 -0
  27. mirix/schemas/knowledge_vault.py +181 -0
  28. mirix/schemas/llm_config.py +187 -0
  29. mirix/schemas/memory.py +318 -0
  30. mirix/schemas/message.py +1315 -0
  31. mirix/schemas/mirix_base.py +107 -0
  32. mirix/schemas/mirix_message.py +411 -0
  33. mirix/schemas/mirix_message_content.py +230 -0
  34. mirix/schemas/mirix_request.py +39 -0
  35. mirix/schemas/mirix_response.py +183 -0
  36. mirix/schemas/openai/__init__.py +1 -0
  37. mirix/schemas/openai/chat_completion_request.py +122 -0
  38. mirix/schemas/openai/chat_completion_response.py +144 -0
  39. mirix/schemas/openai/chat_completions.py +127 -0
  40. mirix/schemas/openai/embedding_response.py +11 -0
  41. mirix/schemas/openai/openai.py +229 -0
  42. mirix/schemas/organization.py +38 -0
  43. mirix/schemas/procedural_memory.py +151 -0
  44. mirix/schemas/providers.py +816 -0
  45. mirix/schemas/resource_memory.py +134 -0
  46. mirix/schemas/sandbox_config.py +132 -0
  47. mirix/schemas/semantic_memory.py +162 -0
  48. mirix/schemas/source.py +96 -0
  49. mirix/schemas/step.py +53 -0
  50. mirix/schemas/tool.py +241 -0
  51. mirix/schemas/tool_rule.py +209 -0
  52. mirix/schemas/usage.py +31 -0
  53. mirix/schemas/user.py +67 -0
mirix/schemas/providers.py
@@ -0,0 +1,816 @@
+ from datetime import datetime
+ from typing import List, Optional
+
+ from pydantic import Field, model_validator
+
+ from mirix.client.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
+ from mirix.llm_api.azure_openai import (
+     get_azure_chat_completions_endpoint,
+     get_azure_embeddings_endpoint,
+ )
+ from mirix.llm_api.azure_openai_constants import AZURE_MODEL_TO_CONTEXT_LENGTH
+ from mirix.log import get_logger
+ from mirix.schemas.embedding_config import EmbeddingConfig
+ from mirix.schemas.llm_config import LLMConfig
+ from mirix.schemas.mirix_base import MirixBase
+
+ logger = get_logger(__name__)
+
+
+ class ProviderBase(MirixBase):
+     __id_prefix__ = "provider"
+
+ class Provider(ProviderBase):
+     id: Optional[str] = Field(
+         None,
+         description="The id of the provider, lazily created by the database manager.",
+     )
+     name: str = Field(..., description="The name of the provider")
+     api_key: Optional[str] = Field(
+         None, description="API key used for requests to the provider."
+     )
+     organization_id: Optional[str] = Field(
+         None, description="The organization id of the user"
+     )
+     updated_at: Optional[datetime] = Field(
+         None, description="The last update timestamp of the provider."
+     )
+
+     def resolve_identifier(self):
+         if not self.id:
+             self.id = ProviderBase._generate_id(prefix=ProviderBase.__id_prefix__)
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         return []
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         return []
+
+     def get_model_context_window(self, model_name: str) -> Optional[int]:
+         raise NotImplementedError
+
+     def provider_tag(self) -> str:
+         """String representation of the provider for display purposes"""
+         raise NotImplementedError
+
+     def get_handle(self, model_name: str) -> str:
+         return f"{self.name}/{model_name}"
+
+ class ProviderCreate(ProviderBase):
+     name: str = Field(..., description="The name of the provider.")
+     api_key: str = Field(..., description="API key used for requests to the provider.")
+
+ class ProviderUpdate(ProviderBase):
+     id: str = Field(..., description="The id of the provider to update.")
+     api_key: str = Field(..., description="API key used for requests to the provider.")
+
+ class MirixProvider(Provider):
+     name: str = "mirix"
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         return [
+             LLMConfig(
+                 model="mirix-free",  # NOTE: renamed
+                 model_endpoint_type="openai",
+                 model_endpoint="https://inference.memgpt.ai",
+                 context_window=8192,
+                 handle=self.get_handle("mirix-free"),
+             )
+         ]
+
+     def list_embedding_models(self):
+         return [
+             EmbeddingConfig(
+                 embedding_model="mirix-free",  # NOTE: renamed
+                 embedding_endpoint_type="hugging-face",
+                 embedding_endpoint="https://embeddings.memgpt.ai",
+                 embedding_dim=1024,
+                 embedding_chunk_size=300,
+                 handle=self.get_handle("mirix-free"),
+             )
+         ]
+
+ class OpenAIProvider(Provider):
+     name: str = "openai"
+     api_key: str = Field(..., description="API key for the OpenAI API.")
+     base_url: str = Field(..., description="Base URL for the OpenAI API.")
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.openai import openai_get_model_list
+
+         # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)...
+         # See: https://openrouter.ai/docs/requests
+         extra_params = (
+             {"supported_parameters": "tools"}
+             if "openrouter.ai" in self.base_url
+             else None
+         )
+         response = openai_get_model_list(
+             self.base_url, api_key=self.api_key, extra_params=extra_params
+         )
+
+         # TogetherAI's response is missing the 'data' field
+         # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+         if "data" in response:
+             data = response["data"]
+         else:
+             data = response
+
+         configs = []
+         for model in data:
+             assert "id" in model, f"OpenAI model missing 'id' field: {model}"
+             model_name = model["id"]
+
+             if "context_length" in model:
+                 # Context length is returned in OpenRouter as "context_length"
+                 context_window_size = model["context_length"]
+             else:
+                 context_window_size = self.get_model_context_window_size(model_name)
+
+             if not context_window_size:
+                 continue
+
+             # TogetherAI includes the type, which we can use to filter out embedding models
+             if self.base_url == "https://api.together.ai/v1":
+                 if "type" in model and model["type"] != "chat":
+                     continue
+
+                 # for TogetherAI, we need to skip the models that don't support JSON mode / function calling
+                 # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
+                 #   "error": {
+                 #     "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
+                 #     "type": "invalid_request_error",
+                 #     "param": null,
+                 #     "code": "constraints_model"
+                 #   }
+                 # }
+                 if "config" not in model:
+                     continue
+                 if "chat_template" not in model["config"]:
+                     continue
+                 if model["config"]["chat_template"] is None:
+                     continue
+                 if "tools" not in model["config"]["chat_template"]:
+                     continue
+                 # if "config" in data and "chat_template" in data["config"] and "tools" not in data["config"]["chat_template"]:
+                 #     continue
+
+             configs.append(
+                 LLMConfig(
+                     model=model_name,
+                     model_endpoint_type="openai",
+                     model_endpoint=self.base_url,
+                     context_window=context_window_size,
+                     handle=self.get_handle(model_name),
+                 )
+             )
+
+         # for OpenAI, sort in reverse order
+         if self.base_url == "https://api.openai.com/v1":
+             # alphanumeric sort
+             configs.sort(key=lambda x: x.model, reverse=True)
+
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         # TODO: actually automatically list models
+         return [
+             EmbeddingConfig(
+                 embedding_model="text-embedding-3-small",
+                 embedding_endpoint_type="openai",
+                 embedding_endpoint="https://api.openai.com/v1",
+                 embedding_dim=1536,
+                 embedding_chunk_size=300,
+                 handle=self.get_handle("text-embedding-3-small"),
+             ),
+             EmbeddingConfig(
+                 embedding_model="text-embedding-3-small",
+                 embedding_endpoint_type="openai",
+                 embedding_endpoint="https://api.openai.com/v1",
+                 embedding_dim=2000,
+                 embedding_chunk_size=300,
+                 handle=self.get_handle("text-embedding-3-small"),
+             ),
+             EmbeddingConfig(
+                 embedding_model="text-embedding-3-large",
+                 embedding_endpoint_type="openai",
+                 embedding_endpoint="https://api.openai.com/v1",
+                 embedding_dim=2000,
+                 embedding_chunk_size=300,
+                 handle=self.get_handle("text-embedding-3-large"),
+             ),
+         ]
+
+     def get_model_context_window_size(self, model_name: str):
+         if model_name in LLM_MAX_TOKENS:
+             return LLM_MAX_TOKENS[model_name]
+         else:
+             return None
+
+ class AnthropicProvider(Provider):
+     name: str = "anthropic"
+     api_key: str = Field(..., description="API key for the Anthropic API.")
+     base_url: str = "https://api.anthropic.com/v1"
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.anthropic import anthropic_get_model_list
+
+         models = anthropic_get_model_list(self.base_url, api_key=self.api_key)
+
+         configs = []
+         for model in models:
+             configs.append(
+                 LLMConfig(
+                     model=model["name"],
+                     model_endpoint_type="anthropic",
+                     model_endpoint=self.base_url,
+                     context_window=model["context_window"],
+                     handle=self.get_handle(model["name"]),
+                 )
+             )
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         return []
+
+ class MistralProvider(Provider):
+     name: str = "mistral"
+     api_key: str = Field(..., description="API key for the Mistral API.")
+     base_url: str = "https://api.mistral.ai/v1"
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.mistral import mistral_get_model_list
+
+         response = mistral_get_model_list(self.base_url, api_key=self.api_key)
+
+         assert "data" in response, (
+             f"Mistral model query response missing 'data' field: {response}"
+         )
+
+         configs = []
+         for model in response["data"]:
+             # If the model supports chat completions and function calling
+             if (
+                 model["capabilities"]["completion_chat"]
+                 and model["capabilities"]["function_calling"]
+             ):
+                 configs.append(
+                     LLMConfig(
+                         model=model["id"],
+                         model_endpoint_type="openai",
+                         model_endpoint=self.base_url,
+                         context_window=model["max_context_length"],
+                         handle=self.get_handle(model["id"]),
+                     )
+                 )
+
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         # Not supported for Mistral
+         return []
+
+     def get_model_context_window(self, model_name: str) -> Optional[int]:
+         # Redoing this is fine because it's a pretty lightweight call
+         models = self.list_llm_models()
+
+         # list_llm_models() returns LLMConfig objects, so use attribute access
+         for m in models:
+             if model_name in m.model:
+                 return int(m.context_window)
+
+         return None
+
+ class OllamaProvider(OpenAIProvider):
+     """Ollama provider that uses the native /api/generate endpoint
+
+     See: https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
+     """
+
+     name: str = "ollama"
+     base_url: str = Field(..., description="Base URL for the Ollama API.")
+     api_key: Optional[str] = Field(
+         None, description="API key for the Ollama API (default: `None`)."
+     )
+     default_prompt_formatter: str = Field(
+         ...,
+         description="Default prompt formatter (aka model wrapper) to use on a /completions style API.",
+     )
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+         import requests
+
+         response = requests.get(f"{self.base_url}/api/tags")
+         if response.status_code != 200:
+             raise Exception(f"Failed to list Ollama models: {response.text}")
+         response_json = response.json()
+
+         configs = []
+         for model in response_json["models"]:
+             context_window = self.get_model_context_window(model["name"])
+             if context_window is None:
+                 logger.debug("Ollama model %s has no context window", model["name"])
+                 continue
+             configs.append(
+                 LLMConfig(
+                     model=model["name"],
+                     model_endpoint_type="ollama",
+                     model_endpoint=self.base_url,
+                     model_wrapper=self.default_prompt_formatter,
+                     context_window=context_window,
+                     handle=self.get_handle(model["name"]),
+                 )
+             )
+         return configs
+
+     def get_model_context_window(self, model_name: str) -> Optional[int]:
+         import requests
+
+         response = requests.post(
+             f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}
+         )
+         response_json = response.json()
+
+         ## thank you vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L1675
+         # possible_keys = [
+         #     # OPT
+         #     "max_position_embeddings",
+         #     # GPT-2
+         #     "n_positions",
+         #     # MPT
+         #     "max_seq_len",
+         #     # ChatGLM2
+         #     "seq_length",
+         #     # Command-R
+         #     "model_max_length",
+         #     # Others
+         #     "max_sequence_length",
+         #     "max_seq_length",
+         #     "seq_len",
+         # ]
+         # max_position_embeddings
+         # parse model cards: nous, dolphin, llama
+         if "model_info" not in response_json:
+             if "error" in response_json:
+                 logger.error(
+                     f"Ollama fetch model info error for {model_name}: {response_json['error']}"
+                 )
+             return None
+         for key, value in response_json["model_info"].items():
+             if "context_length" in key:
+                 return value
+         return None
+
+     def get_model_embedding_dim(self, model_name: str):
+         import requests
+
+         response = requests.post(
+             f"{self.base_url}/api/show", json={"name": model_name, "verbose": True}
+         )
+         response_json = response.json()
+         if "model_info" not in response_json:
+             if "error" in response_json:
+                 logger.error(
+                     f"Ollama fetch model info error for {model_name}: {response_json['error']}"
+                 )
+             return None
+         for key, value in response_json["model_info"].items():
+             if "embedding_length" in key:
+                 return value
+         return None
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
+         import requests
+
+         response = requests.get(f"{self.base_url}/api/tags")
+         if response.status_code != 200:
+             raise Exception(f"Failed to list Ollama models: {response.text}")
+         response_json = response.json()
+
+         configs = []
+         for model in response_json["models"]:
+             embedding_dim = self.get_model_embedding_dim(model["name"])
+             if not embedding_dim:
+                 logger.debug("Ollama model %s has no embedding dimension", model["name"])
+                 continue
+             configs.append(
+                 EmbeddingConfig(
+                     embedding_model=model["name"],
+                     embedding_endpoint_type="ollama",
+                     embedding_endpoint=self.base_url,
+                     embedding_dim=embedding_dim,
+                     embedding_chunk_size=300,
+                     handle=self.get_handle(model["name"]),
+                 )
+             )
+         return configs
+
+ class GroqProvider(OpenAIProvider):
+     name: str = "groq"
+     base_url: str = "https://api.groq.com/openai/v1"
+     api_key: str = Field(..., description="API key for the Groq API.")
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.openai import openai_get_model_list
+
+         response = openai_get_model_list(self.base_url, api_key=self.api_key)
+         configs = []
+         for model in response["data"]:
+             if "context_window" not in model:
+                 continue
+             configs.append(
+                 LLMConfig(
+                     model=model["id"],
+                     model_endpoint_type="groq",
+                     model_endpoint=self.base_url,
+                     context_window=model["context_window"],
+                     handle=self.get_handle(model["id"]),
+                 )
+             )
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         return []
+
+     def get_model_context_window_size(self, model_name: str):
+         raise NotImplementedError
+
+ class TogetherProvider(OpenAIProvider):
+     """TogetherAI provider that uses the /completions API
+
+     TogetherAI can also be used via the /chat/completions API
+     by setting OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key
+     and API URL; however, /completions is preferred because their /chat/completions
+     function calling support is limited.
+     """
+
+     name: str = "together"
+     base_url: str = "https://api.together.ai/v1"
+     api_key: str = Field(..., description="API key for the TogetherAI API.")
+     default_prompt_formatter: str = Field(
+         ...,
+         description="Default prompt formatter (aka model wrapper) to use on the /completions API.",
+     )
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.openai import openai_get_model_list
+
+         response = openai_get_model_list(self.base_url, api_key=self.api_key)
+
+         # TogetherAI's response is missing the 'data' field
+         # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+         if "data" in response:
+             data = response["data"]
+         else:
+             data = response
+
+         configs = []
+         for model in data:
+             assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
+             model_name = model["id"]
+
+             if "context_length" in model:
+                 # Context length is returned in OpenRouter as "context_length"
+                 context_window_size = model["context_length"]
+             else:
+                 context_window_size = self.get_model_context_window_size(model_name)
+
+             # We need the context length for embeddings too
+             if not context_window_size:
+                 continue
+
+             # Skip models that are too small for Mirix
+             if context_window_size <= MIN_CONTEXT_WINDOW:
+                 continue
+
+             # TogetherAI includes the type, which we can use to filter for embedding models
+             if "type" in model and model["type"] not in ["chat", "language"]:
+                 continue
+
+             configs.append(
+                 LLMConfig(
+                     model=model_name,
+                     model_endpoint_type="together",
+                     model_endpoint=self.base_url,
+                     model_wrapper=self.default_prompt_formatter,
+                     context_window=context_window_size,
+                     handle=self.get_handle(model_name),
+                 )
+             )
+
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         # TODO: re-enable once we figure out how to pass API keys through properly
+         return []
+
+         # from mirix.llm_api.openai import openai_get_model_list
+
+         # response = openai_get_model_list(self.base_url, api_key=self.api_key)
+
+         # # TogetherAI's response is missing the 'data' field
+         # # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+         # if "data" in response:
+         #     data = response["data"]
+         # else:
+         #     data = response
+
+         # configs = []
+         # for model in data:
+         #     assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
+         #     model_name = model["id"]
+
+         #     if "context_length" in model:
+         #         # Context length is returned in OpenRouter as "context_length"
+         #         context_window_size = model["context_length"]
+         #     else:
+         #         context_window_size = self.get_model_context_window_size(model_name)
+
+         #     if not context_window_size:
+         #         continue
+
+         #     # TogetherAI includes the type, which we can use to filter out embedding models
+         #     if "type" in model and model["type"] not in ["embedding"]:
+         #         continue
+
+         #     configs.append(
+         #         EmbeddingConfig(
+         #             embedding_model=model_name,
+         #             embedding_endpoint_type="openai",
+         #             embedding_endpoint=self.base_url,
+         #             embedding_dim=context_window_size,
+         #             embedding_chunk_size=300,  # TODO: change?
+         #         )
+         #     )
+
+         # return configs
+
+ class GoogleAIProvider(Provider):
+     # gemini
+     name: str = "google_ai"
+     api_key: str = Field(..., description="API key for the Google AI API.")
+     base_url: str = "https://generativelanguage.googleapis.com"
+
+     def list_llm_models(self):
+         from mirix.llm_api.google_ai import google_ai_get_model_list
+
+         model_options = google_ai_get_model_list(
+             base_url=self.base_url, api_key=self.api_key
+         )
+         # filter by 'generateContent' models
+         model_options = [
+             mo
+             for mo in model_options
+             if "generateContent" in mo["supportedGenerationMethods"]
+         ]
+         model_options = [str(m["name"]) for m in model_options]
+
+         # strip the "models/" prefix from model names
+         model_options = [
+             mo[len("models/") :] if mo.startswith("models/") else mo
+             for mo in model_options
+         ]
+
+         # TODO remove manual filtering for gemini-pro
+         # Add support for all gemini models
+         model_options = [mo for mo in model_options if str(mo).startswith("gemini-")]
+
+         configs = []
+         for model in model_options:
+             configs.append(
+                 LLMConfig(
+                     model=model,
+                     model_endpoint_type="google_ai",
+                     model_endpoint=self.base_url,
+                     context_window=self.get_model_context_window(model),
+                     handle=self.get_handle(model),
+                 )
+             )
+         return configs
+
+     def list_embedding_models(self):
+         from mirix.llm_api.google_ai import google_ai_get_model_list
+
+         # TODO: use base_url instead
+         model_options = google_ai_get_model_list(
+             base_url=self.base_url, api_key=self.api_key
+         )
+         # filter by 'embedContent' models
+         model_options = [
+             mo
+             for mo in model_options
+             if "embedContent" in mo["supportedGenerationMethods"]
+         ]
+         model_options = [str(m["name"]) for m in model_options]
+         model_options = [
+             mo[len("models/") :] if mo.startswith("models/") else mo
+             for mo in model_options
+         ]
+
+         configs = []
+         for model in model_options:
+             configs.append(
+                 EmbeddingConfig(
+                     embedding_model=model,
+                     embedding_endpoint_type="google_ai",
+                     embedding_endpoint=self.base_url,
+                     embedding_dim=768,
+                     embedding_chunk_size=300,  # NOTE: max is 2048
+                     handle=self.get_handle(model),
+                 )
+             )
+         return configs
+
+     def get_model_context_window(self, model_name: str) -> Optional[int]:
+         from mirix.llm_api.google_ai import google_ai_get_model_context_window
+
+         return google_ai_get_model_context_window(
+             self.base_url, self.api_key, model_name
+         )
+
+ class AzureProvider(Provider):
+     name: str = "azure"
+     latest_api_version: str = "2024-09-01-preview"  # https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
+     base_url: str = Field(
+         ...,
+         description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://mirix.openai.azure.com`.",
+     )
+     api_key: str = Field(..., description="API key for the Azure API.")
+     api_version: str = Field(
+         latest_api_version, description="API version for the Azure API"
+     )
+
+     @model_validator(mode="before")
+     def set_default_api_version(cls, values):
+         """
+         This ensures that api_version is always set to the default if None is passed in.
+         """
+         if values.get("api_version") is None:
+             values["api_version"] = cls.model_fields["latest_api_version"].default
+         return values
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.azure_openai import (
+             azure_openai_get_chat_completion_model_list,
+         )
+
+         model_options = azure_openai_get_chat_completion_model_list(
+             self.base_url, api_key=self.api_key, api_version=self.api_version
+         )
+         configs = []
+         for model_option in model_options:
+             model_name = model_option["id"]
+             context_window_size = self.get_model_context_window(model_name)
+             model_endpoint = get_azure_chat_completions_endpoint(
+                 self.base_url, model_name, self.api_version
+             )
+             configs.append(
+                 LLMConfig(
+                     model=model_name,
+                     model_endpoint_type="azure",
+                     model_endpoint=model_endpoint,
+                     context_window=context_window_size,
+                     handle=self.get_handle(model_name),
+                 )
+             )
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         from mirix.llm_api.azure_openai import azure_openai_get_embeddings_model_list
+
+         model_options = azure_openai_get_embeddings_model_list(
+             self.base_url,
+             api_key=self.api_key,
+             api_version=self.api_version,
+             require_embedding_in_name=True,
+         )
+         configs = []
+         for model_option in model_options:
+             model_name = model_option["id"]
+             model_endpoint = get_azure_embeddings_endpoint(
+                 self.base_url, model_name, self.api_version
+             )
+             configs.append(
+                 EmbeddingConfig(
+                     embedding_model=model_name,
+                     embedding_endpoint_type="azure",
+                     embedding_endpoint=model_endpoint,
+                     embedding_dim=768,
+                     embedding_chunk_size=300,  # NOTE: max is 2048
+                     handle=self.get_handle(model_name),
+                 )
+             )
+         return configs
+
+     def get_model_context_window(self, model_name: str) -> Optional[int]:
+         """
+         This is hardcoded for now, since there is no API endpoint to retrieve metadata for a model.
+         """
+         return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, 4096)
+
+ class VLLMChatCompletionsProvider(Provider):
+     """vLLM provider that treats vLLM as an OpenAI /chat/completions proxy"""
+
+     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
+     name: str = "vllm"
+     base_url: str = Field(..., description="Base URL for the vLLM API.")
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.openai import openai_get_model_list
+
+         assert self.base_url, "base_url is required for vLLM provider"
+         response = openai_get_model_list(self.base_url, api_key=None)
+
+         configs = []
+         for model in response["data"]:
+             configs.append(
+                 LLMConfig(
+                     model=model["id"],
+                     model_endpoint_type="openai",
+                     model_endpoint=self.base_url,
+                     context_window=model["max_model_len"],
+                     handle=self.get_handle(model["id"]),
+                 )
+             )
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         # not supported with vLLM
+         return []
+
+ class VLLMCompletionsProvider(Provider):
+     """This uses the /completions API as the backend, not /chat/completions, so we need to specify a model wrapper"""
+
+     # NOTE: vLLM only serves one model at a time (so could configure that through env variables)
+     name: str = "vllm"
+     base_url: str = Field(..., description="Base URL for the vLLM API.")
+     default_prompt_formatter: str = Field(
+         ...,
+         description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.",
+     )
+
+     def list_llm_models(self) -> List[LLMConfig]:
+         from mirix.llm_api.openai import openai_get_model_list
+
+         response = openai_get_model_list(self.base_url, api_key=None)
+
+         configs = []
+         for model in response["data"]:
+             configs.append(
+                 LLMConfig(
+                     model=model["id"],
+                     model_endpoint_type="vllm",
+                     model_endpoint=self.base_url,
+                     model_wrapper=self.default_prompt_formatter,
+                     context_window=model["max_model_len"],
+                     handle=self.get_handle(model["id"]),
+                 )
+             )
+         return configs
+
+     def list_embedding_models(self) -> List[EmbeddingConfig]:
+         # not supported with vLLM
+         return []
+
+ class CohereProvider(OpenAIProvider):
+     pass
+
+ class AnthropicBedrockProvider(Provider):
+     name: str = "bedrock"
+     aws_region: str = Field(..., description="AWS region for Bedrock")
+
+     def list_llm_models(self):
+         from mirix.llm_api.aws_bedrock import bedrock_get_model_list
+
+         models = bedrock_get_model_list(self.aws_region)
+
+         configs = []
+         for model_summary in models:
+             model_arn = model_summary["inferenceProfileArn"]
+             configs.append(
+                 LLMConfig(
+                     model=model_arn,
+                     model_endpoint_type=self.name,
+                     model_endpoint=None,
+                     context_window=self.get_model_context_window(model_arn),
+                     handle=self.get_handle(model_arn),
+                 )
+             )
+         return configs
+
+     def list_embedding_models(self):
+         return []
+
+     def get_model_context_window(self, model_name: str) -> Optional[int]:
+         # Context windows for Claude models
+         from mirix.llm_api.aws_bedrock import bedrock_get_model_context_window
+
+         return bedrock_get_model_context_window(model_name)
+
+     def get_handle(self, model_name: str) -> str:
+         return f"anthropic/{model_name}"