litellm-enterprise 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,7 +40,7 @@ class EnterpriseCallbackControls:
40
40
  #########################################################
41
41
  # premium user check
42
42
  #########################################################
43
- if not EnterpriseCallbackControls._premium_user_check():
43
+ if not EnterpriseCallbackControls._should_allow_dynamic_callback_disabling():
44
44
  return False
45
45
  #########################################################
46
46
  if isinstance(callback, str):
@@ -84,8 +84,15 @@ class EnterpriseCallbackControls:
84
84
  return None
85
85
 
86
86
  @staticmethod
87
- def _premium_user_check():
87
+ def _should_allow_dynamic_callback_disabling():
88
+ import litellm
88
89
  from litellm.proxy.proxy_server import premium_user
90
+
91
+ # Check if admin has disabled this feature
92
+ if litellm.allow_dynamic_callback_disabling is not True:
93
+ verbose_logger.debug("Dynamic callback disabling is disabled by admin via litellm.allow_dynamic_callback_disabling")
94
+ return False
95
+
89
96
  if premium_user:
90
97
  return True
91
98
  verbose_logger.warning(f"Disabling callbacks using request headers is an enterprise feature. {CommonProxyErrors.not_premium_user.value}")
@@ -9,7 +9,7 @@ Callback to log events to a Generic API Endpoint
9
9
  import asyncio
10
10
  import os
11
11
  import traceback
12
- import uuid
12
+ from litellm._uuid import uuid
13
13
  from typing import Dict, List, Optional, Union
14
14
 
15
15
  import litellm
@@ -23,7 +23,7 @@ import litellm
23
23
  from litellm._logging import verbose_proxy_logger
24
24
  from litellm.integrations.custom_logger import CustomLogger
25
25
  from litellm.proxy._types import UserAPIKeyAuth
26
- from litellm.types.utils import Choices, ModelResponse
26
+ from litellm.types.utils import CallTypesLiteral, Choices, ModelResponse
27
27
 
28
28
 
29
29
  class _ENTERPRISE_LlamaGuard(CustomLogger):
@@ -98,15 +98,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
98
98
  self,
99
99
  data: dict,
100
100
  user_api_key_dict: UserAPIKeyAuth,
101
- call_type: Literal[
102
- "completion",
103
- "embeddings",
104
- "image_generation",
105
- "moderation",
106
- "audio_transcription",
107
- "responses",
108
- "mcp_call",
109
- ],
101
+ call_type: CallTypesLiteral,
110
102
  ):
111
103
  """
112
104
  - Calls the Llama Guard Endpoint
@@ -17,6 +17,7 @@ from litellm._logging import verbose_proxy_logger
17
17
  from litellm.integrations.custom_logger import CustomLogger
18
18
  from litellm.proxy._types import UserAPIKeyAuth
19
19
  from litellm.secret_managers.main import get_secret_str
20
+ from litellm.types.utils import CallTypesLiteral
20
21
  from litellm.utils import get_formatted_prompt
21
22
 
22
23
 
@@ -120,15 +121,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
120
121
  self,
121
122
  data: dict,
122
123
  user_api_key_dict: UserAPIKeyAuth,
123
- call_type: Literal[
124
- "completion",
125
- "embeddings",
126
- "image_generation",
127
- "moderation",
128
- "audio_transcription",
129
- "responses",
130
- "mcp_call",
131
- ],
124
+ call_type: CallTypesLiteral,
132
125
  ):
133
126
  """
134
127
  - Calls the LLM Guard Endpoint
@@ -31,6 +31,7 @@ from litellm.types.integrations.pagerduty import (
31
31
  PagerDutyRequestBody,
32
32
  )
33
33
  from litellm.types.utils import (
34
+ CallTypesLiteral,
34
35
  StandardLoggingPayload,
35
36
  StandardLoggingPayloadErrorInformation,
36
37
  )
@@ -119,6 +120,7 @@ class PagerDutyAlerting(SlackAlerting):
119
120
  user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
120
121
  user_api_key_user_email=_meta.get("user_api_key_user_email"),
121
122
  user_api_key_request_route=_meta.get("user_api_key_request_route"),
123
+ user_api_key_auth_metadata=_meta.get("user_api_key_auth_metadata"),
122
124
  )
123
125
  )
124
126
 
@@ -141,17 +143,7 @@ class PagerDutyAlerting(SlackAlerting):
141
143
  user_api_key_dict: UserAPIKeyAuth,
142
144
  cache: DualCache,
143
145
  data: dict,
144
- call_type: Literal[
145
- "completion",
146
- "text_completion",
147
- "embeddings",
148
- "image_generation",
149
- "moderation",
150
- "audio_transcription",
151
- "pass_through_endpoint",
152
- "rerank",
153
- "mcp_call",
154
- ],
146
+ call_type: CallTypesLiteral,
155
147
  ) -> Optional[Union[Exception, str, dict]]:
156
148
  """
157
149
  Example of detecting hanging requests by waiting a given threshold.
@@ -196,7 +188,11 @@ class PagerDutyAlerting(SlackAlerting):
196
188
  user_api_key_alias=user_api_key_dict.key_alias,
197
189
  user_api_key_spend=user_api_key_dict.spend,
198
190
  user_api_key_max_budget=user_api_key_dict.max_budget,
199
- user_api_key_budget_reset_at=user_api_key_dict.budget_reset_at.isoformat() if user_api_key_dict.budget_reset_at else None,
191
+ user_api_key_budget_reset_at=(
192
+ user_api_key_dict.budget_reset_at.isoformat()
193
+ if user_api_key_dict.budget_reset_at
194
+ else None
195
+ ),
200
196
  user_api_key_org_id=user_api_key_dict.org_id,
201
197
  user_api_key_team_id=user_api_key_dict.team_id,
202
198
  user_api_key_user_id=user_api_key_dict.user_id,
@@ -204,6 +200,7 @@ class PagerDutyAlerting(SlackAlerting):
204
200
  user_api_key_end_user_id=user_api_key_dict.end_user_id,
205
201
  user_api_key_user_email=user_api_key_dict.user_email,
206
202
  user_api_key_request_route=user_api_key_dict.request_route,
203
+ user_api_key_auth_metadata=user_api_key_dict.metadata,
207
204
  )
208
205
  )
209
206
 
@@ -11,6 +11,7 @@ from litellm_enterprise.types.enterprise_callbacks.send_emails import (
11
11
  EmailEvent,
12
12
  EmailParams,
13
13
  SendKeyCreatedEmailEvent,
14
+ SendKeyRotatedEmailEvent,
14
15
  )
15
16
 
16
17
  from litellm._logging import verbose_proxy_logger
@@ -19,10 +20,14 @@ from litellm.integrations.email_templates.email_footer import EMAIL_FOOTER
19
20
  from litellm.integrations.email_templates.key_created_email import (
20
21
  KEY_CREATED_EMAIL_TEMPLATE,
21
22
  )
23
+ from litellm.integrations.email_templates.key_rotated_email import (
24
+ KEY_ROTATED_EMAIL_TEMPLATE,
25
+ )
22
26
  from litellm.integrations.email_templates.user_invitation_email import (
23
27
  USER_INVITATION_EMAIL_TEMPLATE,
24
28
  )
25
29
  from litellm.proxy._types import InvitationNew, UserAPIKeyAuth, WebhookEvent
30
+ from litellm.secret_managers.main import get_secret_bool
26
31
  from litellm.types.integrations.slack_alerting import LITELLM_LOGO_URL
27
32
 
28
33
 
@@ -32,6 +37,7 @@ class BaseEmailLogger(CustomLogger):
32
37
  DEFAULT_SUBJECT_TEMPLATES = {
33
38
  EmailEvent.new_user_invitation: "LiteLLM: {event_message}",
34
39
  EmailEvent.virtual_key_created: "LiteLLM: {event_message}",
40
+ EmailEvent.virtual_key_rotated: "LiteLLM: {event_message}",
35
41
  }
36
42
 
37
43
  async def send_user_invitation_email(self, event: WebhookEvent):
@@ -83,11 +89,58 @@ class BaseEmailLogger(CustomLogger):
83
89
  f"send_key_created_email_event: {json.dumps(send_key_created_email_event, indent=4, default=str)}"
84
90
  )
85
91
 
92
+ # Check if API key should be included in email
93
+ include_api_key = get_secret_bool(secret_name="EMAIL_INCLUDE_API_KEY", default_value=True)
94
+ if include_api_key is None:
95
+ include_api_key = True # Default to True if not set
96
+ key_token_display = send_key_created_email_event.virtual_key if include_api_key else "[Key hidden for security - retrieve from dashboard]"
97
+
86
98
  email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format(
87
99
  email_logo_url=email_params.logo_url,
88
100
  recipient_email=email_params.recipient_email,
89
101
  key_budget=self._format_key_budget(send_key_created_email_event.max_budget),
90
- key_token=send_key_created_email_event.virtual_key,
102
+ key_token=key_token_display,
103
+ base_url=email_params.base_url,
104
+ email_support_contact=email_params.support_contact,
105
+ email_footer=email_params.signature,
106
+ )
107
+
108
+ await self.send_email(
109
+ from_email=self.DEFAULT_LITELLM_EMAIL,
110
+ to_email=[email_params.recipient_email],
111
+ subject=email_params.subject,
112
+ html_body=email_html_content,
113
+ )
114
+ pass
115
+
116
+ async def send_key_rotated_email(
117
+ self, send_key_rotated_email_event: SendKeyRotatedEmailEvent
118
+ ):
119
+ """
120
+ Send email to user after rotating key for the user
121
+ """
122
+ email_params = await self._get_email_params(
123
+ user_id=send_key_rotated_email_event.user_id,
124
+ user_email=send_key_rotated_email_event.user_email,
125
+ email_event=EmailEvent.virtual_key_rotated,
126
+ event_message=send_key_rotated_email_event.event_message,
127
+ )
128
+
129
+ verbose_proxy_logger.debug(
130
+ f"send_key_rotated_email_event: {json.dumps(send_key_rotated_email_event, indent=4, default=str)}"
131
+ )
132
+
133
+ # Check if API key should be included in email
134
+ include_api_key = get_secret_bool(secret_name="EMAIL_INCLUDE_API_KEY", default_value=True)
135
+ if include_api_key is None:
136
+ include_api_key = True # Default to True if not set
137
+ key_token_display = send_key_rotated_email_event.virtual_key if include_api_key else "[Key hidden for security - retrieve from dashboard]"
138
+
139
+ email_html_content = KEY_ROTATED_EMAIL_TEMPLATE.format(
140
+ email_logo_url=email_params.logo_url,
141
+ recipient_email=email_params.recipient_email,
142
+ key_budget=self._format_key_budget(send_key_rotated_email_event.max_budget),
143
+ key_token=key_token_display,
91
144
  base_url=email_params.base_url,
92
145
  email_support_contact=email_params.support_contact,
93
146
  email_footer=email_params.signature,
@@ -159,6 +212,13 @@ class BaseEmailLogger(CustomLogger):
159
212
  self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.virtual_key_created],
160
213
  "key created subject template"
161
214
  )
215
+ elif email_event == EmailEvent.virtual_key_rotated:
216
+ custom_subject_key_rotated = os.getenv("EMAIL_SUBJECT_KEY_ROTATED", None)
217
+ subject_template = get_custom_or_default(
218
+ custom_subject_key_rotated,
219
+ self.DEFAULT_SUBJECT_TEMPLATES[EmailEvent.virtual_key_rotated],
220
+ "key rotated subject template"
221
+ )
162
222
  else:
163
223
  subject_template = "LiteLLM: {event_message}"
164
224
 
@@ -29,11 +29,10 @@ class EnterpriseCustomGuardrailHelper:
29
29
  if event_hook is None or not isinstance(event_hook, Mode):
30
30
  return None
31
31
 
32
- metadata: dict = data.get("litellm_metadata") or data.get("metadata", {})
33
32
  proxy_server_request = data.get("proxy_server_request", {})
34
33
 
35
34
  request_tags = StandardLoggingPayloadSetup._get_request_tags(
36
- metadata=metadata,
35
+ litellm_params=data,
37
36
  proxy_server_request=proxy_server_request,
38
37
  )
39
38
 
@@ -2,7 +2,7 @@
2
2
  Polls LiteLLM_ManagedObjectTable to check if the batch job is complete, and if the cost has been tracked.
3
3
  """
4
4
 
5
- import uuid
5
+ from litellm._uuid import uuid
6
6
  from datetime import datetime
7
7
  from typing import TYPE_CHECKING, Optional, cast
8
8
 
@@ -57,7 +57,6 @@ class CheckBatchCost:
57
57
  "file_purpose": "batch",
58
58
  }
59
59
  )
60
-
61
60
  completed_jobs = []
62
61
 
63
62
  for job in jobs:
@@ -139,7 +138,7 @@ class CheckBatchCost:
139
138
  custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
140
139
  litellm_model_name = deployment_info.litellm_params.model
141
140
 
142
- _, llm_provider, _, _ = get_llm_provider(
141
+ model_name, llm_provider, _, _ = get_llm_provider(
143
142
  model=litellm_model_name,
144
143
  custom_llm_provider=custom_llm_provider,
145
144
  )
@@ -148,9 +147,9 @@ class CheckBatchCost:
148
147
  await calculate_batch_cost_and_usage(
149
148
  file_content_dictionary=file_content_as_dict,
150
149
  custom_llm_provider=llm_provider, # type: ignore
150
+ model_name=model_name,
151
151
  )
152
152
  )
153
-
154
153
  logging_obj = LiteLLMLogging(
155
154
  model=batch_models[0],
156
155
  messages=[{"role": "user", "content": "<retrieve_batch>"}],
@@ -4,12 +4,12 @@
4
4
  import asyncio
5
5
  import base64
6
6
  import json
7
- import uuid
8
7
  from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
9
8
 
10
9
  from fastapi import HTTPException
11
10
 
12
11
  from litellm import Router, verbose_logger
12
+ from litellm._uuid import uuid
13
13
  from litellm.caching.caching import DualCache
14
14
  from litellm.integrations.custom_logger import CustomLogger
15
15
  from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
@@ -36,6 +36,7 @@ from litellm.types.llms.openai import (
36
36
  OpenAIFilesPurpose,
37
37
  )
38
38
  from litellm.types.utils import (
39
+ CallTypesLiteral,
39
40
  LiteLLMBatch,
40
41
  LiteLLMFineTuningJob,
41
42
  LLMResponseTypes,
@@ -152,7 +153,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
152
153
  "status": file_object.status,
153
154
  },
154
155
  "update": {}, # don't do anything if it already exists
155
- }
156
+ },
156
157
  )
157
158
 
158
159
  async def get_unified_file_id(
@@ -224,9 +225,10 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
224
225
  where={"unified_object_id": unified_object_id}
225
226
  )
226
227
  )
228
+
227
229
  if managed_object:
228
230
  return managed_object.created_by == user_id
229
- return False
231
+ return True # don't raise error if managed object is not found
230
232
 
231
233
  async def get_user_created_file_ids(
232
234
  self, user_api_key_dict: UserAPIKeyAuth, model_object_ids: List[str]
@@ -271,27 +273,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
271
273
  user_api_key_dict: UserAPIKeyAuth,
272
274
  cache: DualCache,
273
275
  data: Dict,
274
- call_type: Literal[
275
- "completion",
276
- "text_completion",
277
- "embeddings",
278
- "image_generation",
279
- "moderation",
280
- "audio_transcription",
281
- "pass_through_endpoint",
282
- "rerank",
283
- "acreate_batch",
284
- "aretrieve_batch",
285
- "acreate_file",
286
- "afile_list",
287
- "afile_delete",
288
- "afile_content",
289
- "acreate_fine_tuning_job",
290
- "aretrieve_fine_tuning_job",
291
- "alist_fine_tuning_jobs",
292
- "acancel_fine_tuning_job",
293
- "mcp_call",
294
- ],
276
+ call_type: CallTypesLiteral,
295
277
  ) -> Union[Exception, str, Dict, None]:
296
278
  """
297
279
  - Detect litellm_proxy/ file_id
@@ -314,6 +296,16 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
314
296
  file_ids, user_api_key_dict.parent_otel_span
315
297
  )
316
298
 
299
+ data["model_file_id_mapping"] = model_file_id_mapping
300
+ elif call_type == CallTypes.aresponses.value or call_type == CallTypes.responses.value:
301
+ # Handle managed files in responses API input
302
+ input_data = data.get("input")
303
+ if input_data:
304
+ file_ids = self.get_file_ids_from_responses_input(input_data)
305
+ if file_ids:
306
+ model_file_id_mapping = await self.get_model_file_id_mapping(
307
+ file_ids, user_api_key_dict.parent_otel_span
308
+ )
317
309
  data["model_file_id_mapping"] = model_file_id_mapping
318
310
  elif call_type == CallTypes.afile_content.value:
319
311
  retrieve_file_id = cast(Optional[str], data.get("file_id"))
@@ -471,6 +463,47 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
471
463
  file_ids.append(file_id)
472
464
  return file_ids
473
465
 
466
+ def get_file_ids_from_responses_input(
467
+ self, input: Union[str, List[Dict[str, Any]]]
468
+ ) -> List[str]:
469
+ """
470
+ Gets file ids from responses API input.
471
+
472
+ The input can be:
473
+ - A string (no files)
474
+ - A list of input items, where each item can have:
475
+ - type: "input_file" with file_id
476
+ - content: a list that can contain items with type: "input_file" and file_id
477
+ """
478
+ file_ids: List[str] = []
479
+
480
+ if isinstance(input, str):
481
+ return file_ids
482
+
483
+ if not isinstance(input, list):
484
+ return file_ids
485
+
486
+ for item in input:
487
+ if not isinstance(item, dict):
488
+ continue
489
+
490
+ # Check for direct input_file type
491
+ if item.get("type") == "input_file":
492
+ file_id = item.get("file_id")
493
+ if file_id:
494
+ file_ids.append(file_id)
495
+
496
+ # Check for input_file in content array
497
+ content = item.get("content")
498
+ if isinstance(content, list):
499
+ for content_item in content:
500
+ if isinstance(content_item, dict) and content_item.get("type") == "input_file":
501
+ file_id = content_item.get("file_id")
502
+ if file_id:
503
+ file_ids.append(file_id)
504
+
505
+ return file_ids
506
+
474
507
  async def get_model_file_id_mapping(
475
508
  self, file_ids: List[str], litellm_parent_otel_span: Span
476
509
  ) -> dict:
@@ -496,7 +529,6 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
496
529
  for file_id in file_ids:
497
530
  ## CHECK IF FILE ID IS MANAGED BY LITELM
498
531
  is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
499
-
500
532
  if is_base64_unified_file_id:
501
533
  litellm_managed_file_ids.append(file_id)
502
534
 
@@ -507,6 +539,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
507
539
  unified_file_object = await self.get_unified_file_id(
508
540
  file_id, litellm_parent_otel_span
509
541
  )
542
+
510
543
  if unified_file_object:
511
544
  file_id_mapping[file_id] = unified_file_object.model_mappings
512
545
 
@@ -782,18 +815,21 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
782
815
  llm_router: Router,
783
816
  **data: Dict,
784
817
  ) -> OpenAIFileObject:
785
- file_id = convert_b64_uid_to_unified_uid(file_id)
818
+
819
+ # file_id = convert_b64_uid_to_unified_uid(file_id)
786
820
  model_file_id_mapping = await self.get_model_file_id_mapping(
787
821
  [file_id], litellm_parent_otel_span
788
822
  )
823
+
789
824
  specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
790
825
  if specific_model_file_id_mapping:
791
- for model_id, file_id in specific_model_file_id_mapping.items():
792
- await llm_router.afile_delete(model=model_id, file_id=file_id, **data) # type: ignore
826
+ for model_id, model_file_id in specific_model_file_id_mapping.items():
827
+ await llm_router.afile_delete(model=model_id, file_id=model_file_id, **data) # type: ignore
793
828
 
794
829
  stored_file_object = await self.delete_unified_file_id(
795
830
  file_id, litellm_parent_otel_span
796
831
  )
832
+
797
833
  if stored_file_object:
798
834
  return stored_file_object
799
835
  else:
@@ -814,6 +850,7 @@ class _PROXY_LiteLLMManagedFiles(CustomLogger, BaseFileEndpoints):
814
850
  model_file_id_mapping
815
851
  or await self.get_model_file_id_mapping([file_id], litellm_parent_otel_span)
816
852
  )
853
+
817
854
  specific_model_file_id_mapping = model_file_id_mapping.get(file_id)
818
855
 
819
856
  if specific_model_file_id_mapping:
@@ -2,7 +2,6 @@
2
2
  Enterprise internal user management endpoints
3
3
  """
4
4
 
5
- import os
6
5
 
7
6
  from fastapi import APIRouter, Depends, HTTPException
8
7
 
@@ -22,9 +22,21 @@ def add_team_member_key_duration(
22
22
  return data
23
23
 
24
24
 
25
+ def add_team_organization_id(
26
+ team_table: Optional[LiteLLM_TeamTable],
27
+ data: GenerateKeyRequest,
28
+ ) -> GenerateKeyRequest:
29
+ if team_table is None:
30
+ return data
31
+ setattr(data, "organization_id", team_table.organization_id)
32
+ return data
33
+
34
+
25
35
  def apply_enterprise_key_management_params(
26
36
  data: GenerateKeyRequest,
27
37
  team_table: Optional[LiteLLM_TeamTable],
28
38
  ) -> GenerateKeyRequest:
39
+
29
40
  data = add_team_member_key_duration(team_table, data)
41
+ data = add_team_organization_id(team_table, data)
30
42
  return data
@@ -9,14 +9,19 @@ All /vector_store management endpoints
9
9
  """
10
10
 
11
11
  import copy
12
+ import json
12
13
  from typing import List, Optional
13
14
 
14
- from fastapi import APIRouter, Depends, HTTPException, Request, Response
15
+ from fastapi import APIRouter, Depends, HTTPException
15
16
 
16
17
  import litellm
17
18
  from litellm._logging import verbose_proxy_logger
18
19
  from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
19
- from litellm.proxy._types import UserAPIKeyAuth
20
+ from litellm.proxy._types import (
21
+ LiteLLM_ManagedVectorStoresTable,
22
+ ResponseLiteLLM_ManagedVectorStore,
23
+ UserAPIKeyAuth,
24
+ )
20
25
  from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
21
26
  from litellm.types.vector_stores import (
22
27
  LiteLLM_ManagedVectorStore,
@@ -29,6 +34,7 @@ from litellm.vector_stores.vector_store_registry import VectorStoreRegistry
29
34
 
30
35
  router = APIRouter()
31
36
 
37
+
32
38
  ########################################################
33
39
  # Management Endpoints
34
40
  ########################################################
@@ -79,7 +85,9 @@ async def new_vector_store(
79
85
  litellm_params_json: Optional[str] = None
80
86
  _input_litellm_params: dict = vector_store.get("litellm_params", {}) or {}
81
87
  if _input_litellm_params is not None:
82
- litellm_params_dict = GenericLiteLLMParams(**_input_litellm_params).model_dump(exclude_none=True)
88
+ litellm_params_dict = GenericLiteLLMParams(
89
+ **_input_litellm_params
90
+ ).model_dump(exclude_none=True)
83
91
  litellm_params_json = safe_dumps(litellm_params_dict)
84
92
  del vector_store["litellm_params"]
85
93
 
@@ -227,6 +235,7 @@ async def delete_vector_store(
227
235
  "/vector_store/info",
228
236
  tags=["vector store management"],
229
237
  dependencies=[Depends(user_api_key_auth)],
238
+ response_model=ResponseLiteLLM_ManagedVectorStore,
230
239
  )
231
240
  async def get_vector_store_info(
232
241
  data: VectorStoreInfoRequest,
@@ -239,8 +248,39 @@ async def get_vector_store_info(
239
248
  raise HTTPException(status_code=500, detail="Database not connected")
240
249
 
241
250
  try:
242
- vector_store = await prisma_client.db.litellm_managedvectorstorestable.find_unique(
243
- where={"vector_store_id": data.vector_store_id}
251
+ if litellm.vector_store_registry is not None:
252
+ vector_store = litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry(
253
+ vector_store_id=data.vector_store_id
254
+ )
255
+ if vector_store is not None:
256
+ vector_store_metadata = vector_store.get("vector_store_metadata")
257
+ # Parse metadata if it's a JSON string
258
+ parsed_metadata: Optional[dict] = None
259
+ if isinstance(vector_store_metadata, str):
260
+ parsed_metadata = json.loads(vector_store_metadata)
261
+ elif isinstance(vector_store_metadata, dict):
262
+ parsed_metadata = vector_store_metadata
263
+
264
+ vector_store_pydantic_obj = LiteLLM_ManagedVectorStoresTable(
265
+ vector_store_id=vector_store.get("vector_store_id") or "",
266
+ custom_llm_provider=vector_store.get("custom_llm_provider") or "",
267
+ vector_store_name=vector_store.get("vector_store_name") or None,
268
+ vector_store_description=vector_store.get(
269
+ "vector_store_description"
270
+ )
271
+ or None,
272
+ vector_store_metadata=parsed_metadata,
273
+ created_at=vector_store.get("created_at") or None,
274
+ updated_at=vector_store.get("updated_at") or None,
275
+ litellm_credential_name=vector_store.get("litellm_credential_name"),
276
+ litellm_params=vector_store.get("litellm_params") or None,
277
+ )
278
+ return {"vector_store": vector_store_pydantic_obj}
279
+
280
+ vector_store = (
281
+ await prisma_client.db.litellm_managedvectorstorestable.find_unique(
282
+ where={"vector_store_id": data.vector_store_id}
283
+ )
244
284
  )
245
285
  if vector_store is None:
246
286
  raise HTTPException(
@@ -248,7 +288,7 @@ async def get_vector_store_info(
248
288
  detail=f"Vector store with ID {data.vector_store_id} not found",
249
289
  )
250
290
 
251
- vector_store_dict = vector_store.model_dump()
291
+ vector_store_dict = vector_store.model_dump() # type: ignore[attr-defined]
252
292
  return {"vector_store": vector_store_dict}
253
293
  except Exception as e:
254
294
  verbose_proxy_logger.exception(f"Error getting vector store info: {str(e)}")
@@ -274,7 +314,9 @@ async def update_vector_store(
274
314
  update_data = data.model_dump(exclude_unset=True)
275
315
  vector_store_id = update_data.pop("vector_store_id")
276
316
  if update_data.get("vector_store_metadata") is not None:
277
- update_data["vector_store_metadata"] = safe_dumps(update_data["vector_store_metadata"])
317
+ update_data["vector_store_metadata"] = safe_dumps(
318
+ update_data["vector_store_metadata"]
319
+ )
278
320
 
279
321
  updated = await prisma_client.db.litellm_managedvectorstorestable.update(
280
322
  where={"vector_store_id": vector_store_id},
@@ -1,10 +1,11 @@
1
1
  import enum
2
- from typing import Dict, List
2
+ from typing import Dict, List, Optional
3
3
 
4
4
  from pydantic import BaseModel, Field
5
5
 
6
6
  from litellm.proxy._types import WebhookEvent
7
7
 
8
+
8
9
  class EmailParams(BaseModel):
9
10
  logo_url: str
10
11
  support_contact: str
@@ -22,9 +23,19 @@ class SendKeyCreatedEmailEvent(WebhookEvent):
22
23
  """
23
24
 
24
25
 
26
+ class SendKeyRotatedEmailEvent(WebhookEvent):
27
+ virtual_key: str
28
+ key_alias: Optional[str] = None
29
+ """
30
+ The virtual key that was rotated
31
+ this will be sk-123xxx, since we will be emailing this to the user to start using the new key
32
+ """
33
+
34
+
25
35
  class EmailEvent(str, enum.Enum):
26
36
  virtual_key_created = "Virtual Key Created"
27
37
  new_user_invitation = "New User Invitation"
38
+ virtual_key_rotated = "Virtual Key Rotated"
28
39
 
29
40
  class EmailEventSettings(BaseModel):
30
41
  event: EmailEvent
@@ -37,8 +48,9 @@ class DefaultEmailSettings(BaseModel):
37
48
  """Default settings for email events"""
38
49
  settings: Dict[EmailEvent, bool] = Field(
39
50
  default_factory=lambda: {
40
- EmailEvent.virtual_key_created: False, # Off by default
51
+ EmailEvent.virtual_key_created: True, # On by default
41
52
  EmailEvent.new_user_invitation: True, # On by default
53
+ EmailEvent.virtual_key_rotated: True, # On by default
42
54
  }
43
55
  )
44
56
  def to_dict(self) -> Dict[str, bool]: