mlrun 1.10.0rc13__py3-none-any.whl → 1.10.0rc15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (47)
  1. mlrun/artifacts/base.py +0 -31
  2. mlrun/artifacts/llm_prompt.py +106 -20
  3. mlrun/artifacts/manager.py +0 -5
  4. mlrun/common/constants.py +0 -1
  5. mlrun/common/schemas/__init__.py +1 -0
  6. mlrun/common/schemas/model_monitoring/__init__.py +1 -0
  7. mlrun/common/schemas/model_monitoring/functions.py +1 -1
  8. mlrun/common/schemas/model_monitoring/model_endpoints.py +10 -0
  9. mlrun/common/schemas/workflow.py +0 -1
  10. mlrun/config.py +1 -1
  11. mlrun/datastore/model_provider/model_provider.py +42 -14
  12. mlrun/datastore/model_provider/openai_provider.py +96 -15
  13. mlrun/db/base.py +14 -0
  14. mlrun/db/httpdb.py +42 -9
  15. mlrun/db/nopdb.py +8 -0
  16. mlrun/execution.py +16 -7
  17. mlrun/model.py +15 -0
  18. mlrun/model_monitoring/__init__.py +1 -0
  19. mlrun/model_monitoring/applications/base.py +176 -20
  20. mlrun/model_monitoring/db/_schedules.py +84 -24
  21. mlrun/model_monitoring/db/tsdb/base.py +72 -1
  22. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +7 -1
  23. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +37 -0
  24. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +25 -0
  25. mlrun/model_monitoring/helpers.py +26 -4
  26. mlrun/projects/project.py +38 -12
  27. mlrun/runtimes/daskjob.py +6 -0
  28. mlrun/runtimes/mpijob/abstract.py +6 -0
  29. mlrun/runtimes/mpijob/v1.py +6 -0
  30. mlrun/runtimes/nuclio/application/application.py +2 -0
  31. mlrun/runtimes/nuclio/function.py +6 -0
  32. mlrun/runtimes/nuclio/serving.py +12 -11
  33. mlrun/runtimes/pod.py +21 -0
  34. mlrun/runtimes/remotesparkjob.py +6 -0
  35. mlrun/runtimes/sparkjob/spark3job.py +6 -0
  36. mlrun/serving/__init__.py +2 -0
  37. mlrun/serving/server.py +95 -26
  38. mlrun/serving/states.py +130 -10
  39. mlrun/utils/helpers.py +36 -12
  40. mlrun/utils/retryer.py +15 -2
  41. mlrun/utils/version/version.json +2 -2
  42. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/METADATA +3 -8
  43. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/RECORD +47 -47
  44. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/WHEEL +0 -0
  45. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/entry_points.txt +0 -0
  46. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/licenses/LICENSE +0 -0
  47. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/top_level.txt +0 -0
mlrun/artifacts/base.py CHANGED
@@ -16,7 +16,6 @@ import os
 import pathlib
 import tempfile
 import typing
-import warnings
 import zipfile
 
 import yaml
@@ -876,36 +875,6 @@ def generate_target_path(item: Artifact, artifact_path, producer):
     return f"{artifact_path}{item.key}{suffix}"
 
 
-# TODO: Remove once data migration v5 is obsolete
-def convert_legacy_artifact_to_new_format(
-    legacy_artifact: dict,
-) -> Artifact:
-    """Converts a legacy artifact to a new format.
-    :param legacy_artifact: The legacy artifact to convert.
-    :return: The converted artifact.
-    """
-    artifact_key = legacy_artifact.get("key", "")
-    artifact_tag = legacy_artifact.get("tag", "")
-    if artifact_tag:
-        artifact_key = f"{artifact_key}:{artifact_tag}"
-    # TODO: Remove once data migration v5 is obsolete
-    warnings.warn(
-        f"Converting legacy artifact '{artifact_key}' to new format. This will not be supported in MLRun 1.10.0. "
-        f"Make sure to save the artifact/project in the new format.",
-        FutureWarning,
-    )
-
-    artifact = mlrun.artifacts.artifact_types.get(
-        legacy_artifact.get("kind", "artifact"), mlrun.artifacts.Artifact
-    )()
-
-    artifact.metadata = artifact.metadata.from_dict(legacy_artifact)
-    artifact.spec = artifact.spec.from_dict(legacy_artifact)
-    artifact.status = artifact.status.from_dict(legacy_artifact)
-
-    return artifact
-
-
 def fill_artifact_object_hash(object_dict, iteration=None, producer_id=None):
     # remove artifact related fields before calculating hash
     object_dict.setdefault("metadata", {})
mlrun/artifacts/llm_prompt.py CHANGED
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import tempfile
 from typing import Optional, Union
 
 import mlrun
 import mlrun.artifacts.model as model_art
-import mlrun.common
+import mlrun.common.schemas
 from mlrun.artifacts import Artifact, ArtifactMetadata, ArtifactSpec
 from mlrun.utils import StorePrefix, logger
 
@@ -25,16 +26,18 @@ MAX_PROMPT_LENGTH = 1024
 
 class LLMPromptArtifactSpec(ArtifactSpec):
     _dict_fields = ArtifactSpec._dict_fields + [
-        "prompt_string",
+        "prompt_template",
         "prompt_legend",
         "model_configuration",
         "description",
     ]
+    PROMPT_TEMPLATE_KEYS = ("content", "role")
+    PROMPT_LEGENDS_KEYS = ("field", "description")
 
     def __init__(
         self,
         model_artifact: Union[model_art.ModelArtifact, str] = None,
-        prompt_string: Optional[str] = None,
+        prompt_template: Optional[list[dict]] = None,
         prompt_path: Optional[str] = None,
         prompt_legend: Optional[dict] = None,
        model_configuration: Optional[dict] = None,
@@ -42,22 +45,26 @@ class LLMPromptArtifactSpec(ArtifactSpec):
         target_path: Optional[str] = None,
         **kwargs,
     ):
-        if prompt_string and prompt_path:
+        if prompt_template and prompt_path:
             raise mlrun.errors.MLRunInvalidArgumentError(
-                "Cannot specify both 'prompt_string' and 'prompt_path'"
+                "Cannot specify both 'prompt_template' and 'prompt_path'"
             )
-
+        if prompt_legend:
+            self._verify_prompt_legend(prompt_legend)
+        if prompt_path:
+            self._verify_prompt_path(prompt_path)
+        if prompt_template:
+            self._verify_prompt_template(prompt_template)
         super().__init__(
             src_path=prompt_path,
             target_path=target_path,
             parent_uri=model_artifact.uri
             if isinstance(model_artifact, model_art.ModelArtifact)
             else model_artifact,
-            body=prompt_string,
             **kwargs,
         )
 
-        self.prompt_string = prompt_string
+        self.prompt_template = prompt_template
         self.prompt_legend = prompt_legend
         self.model_configuration = model_configuration
         self.description = description
@@ -67,10 +74,78 @@ class LLMPromptArtifactSpec(ArtifactSpec):
             else None
         )
 
+    def _verify_prompt_template(self, prompt_template):
+        if not (
+            isinstance(prompt_template, list)
+            and all(isinstance(item, dict) for item in prompt_template)
+        ):
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Expected prompt_template to be a list of dicts"
+            )
+        keys_to_pop = []
+        for message in prompt_template:
+            for key in message.keys():
+                if isinstance(key, str):
+                    if key.lower() not in self.PROMPT_TEMPLATE_KEYS:
+                        raise mlrun.errors.MLRunInvalidArgumentError(
+                            f"Expected prompt_template to contain dict that "
+                            f"only has keys from {self.PROMPT_TEMPLATE_KEYS}"
+                        )
+                    else:
+                        if not key.islower():
+                            message[key.lower()] = message[key]
+                            keys_to_pop.append(key)
+                else:
+                    raise mlrun.errors.MLRunInvalidArgumentError(
+                        f"Expected prompt_template to contain dict that only"
+                        f" has str keys got {key} of type {type(key)}"
+                    )
+            for key_to_pop in keys_to_pop:
+                message.pop(key_to_pop)
+
     @property
     def model_uri(self):
         return self.parent_uri
 
+    @staticmethod
+    def _verify_prompt_legend(prompt_legend: dict):
+        if prompt_legend is None:
+            return True
+        for place_holder, body_map in prompt_legend.items():
+            if isinstance(body_map, dict):
+                if body_map.get("field") is None:
+                    body_map["field"] = place_holder
+                body_map["description"] = body_map.get("description")
+                if diff := set(body_map.keys()) - set(
+                    LLMPromptArtifactSpec.PROMPT_LEGENDS_KEYS
+                ):
+                    raise mlrun.errors.MLRunInvalidArgumentError(
+                        "prompt_legend values must contain only 'field' and "
+                        f"'description' keys, got extra fields: {diff}"
+                    )
+            else:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    f"Wrong prompt_legend format, {place_holder} is not mapped to dict"
+                )
+
+    @staticmethod
+    def _verify_prompt_path(prompt_path: str):
+        with mlrun.datastore.store_manager.object(prompt_path).open(mode="r") as p_file:
+            try:
+                json.load(p_file)
+            except json.JSONDecodeError:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    f"Failed on decoding str in path "
+                    f"{prompt_path} expected file to contain a "
+                    f"json format."
+                )
+
+    def get_body(self):
+        if self.prompt_template:
+            return json.dumps(self.prompt_template)
+        else:
+            return None
+
 
 class LLMPromptArtifact(Artifact):
     """
@@ -90,7 +165,7 @@ class LLMPromptArtifact(Artifact):
         model_artifact: Union[
             model_art.ModelArtifact, str
         ] = None,  # TODO support partial model uri
-        prompt_string: Optional[str] = None,
+        prompt_template: Optional[list[dict]] = None,
         prompt_path: Optional[str] = None,
         prompt_legend: Optional[dict] = None,
         model_configuration: Optional[dict] = None,
@@ -99,7 +174,7 @@ class LLMPromptArtifact(Artifact):
         **kwargs,
     ):
         llm_prompt_spec = LLMPromptArtifactSpec(
-            prompt_string=prompt_string,
+            prompt_template=prompt_template,
             prompt_path=prompt_path,
             prompt_legend=prompt_legend,
             model_artifact=model_artifact,
@@ -137,33 +212,44 @@ class LLMPromptArtifact(Artifact):
             return self.spec._model_artifact
         return None
 
-    def read_prompt(self) -> Optional[str]:
+    def read_prompt(self) -> Optional[Union[str, list[dict]]]:
         """
-        Read the prompt string from the artifact.
+        Read the prompt from the artifact: the in-memory prompt template if set,
+        otherwise the JSON stored at the target path.
+        :return: the prompt template as a list of dicts
         """
-        if self.spec.prompt_string:
-            return self.spec.prompt_string
+        if self.spec.prompt_template:
+            return self.spec.prompt_template
         if self.spec.target_path:
             with mlrun.datastore.store_manager.object(url=self.spec.target_path).open(
                 mode="r"
             ) as p_file:
-                return p_file.read()
+                try:
+                    return json.load(p_file)
+                except json.JSONDecodeError:
+                    raise mlrun.errors.MLRunInvalidArgumentError(
+                        f"Failed on decoding str in path "
+                        f"{self.spec.target_path} expected file to contain a "
+                        f"json format."
+                    )
 
     def before_log(self):
         """
         Prepare the artifact before logging.
         This method is called before the artifact is logged.
        """
-        if self.spec.prompt_string and len(self.spec.prompt_string) > MAX_PROMPT_LENGTH:
+        if (
+            self.spec.prompt_template
+            and len(str(self.spec.prompt_template)) > MAX_PROMPT_LENGTH
+        ):
             logger.debug(
                 "Prompt string exceeds maximum length, saving to a temporary file."
             )
             with tempfile.NamedTemporaryFile(
-                delete=False, mode="w", suffix=".txt"
+                delete=False, mode="w", suffix=".json"
             ) as temp_file:
-                temp_file.write(self.spec.prompt_string)
+                temp_file.write(json.dumps(self.spec.prompt_template))
                 self.spec.src_path = temp_file.name
-            self.spec.prompt_string = None
+            self.spec.prompt_template = None
             self._src_is_temp = True
-
         super().before_log()
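
The change above replaces the free-form `prompt_string` with a structured `prompt_template`: a list of chat-message dicts whose keys are restricted to `PROMPT_TEMPLATE_KEYS` ("role", "content"), plus a validated `prompt_legend`. A minimal usage sketch, assuming `LLMPromptArtifact` accepts a `key` like other artifacts; the key and model URI below are hypothetical:

```python
from mlrun.artifacts.llm_prompt import LLMPromptArtifact

prompt = LLMPromptArtifact(
    key="support-prompt",  # hypothetical artifact key
    model_artifact="store://models/my-project/my-llm:latest",  # hypothetical model URI
    # prompt_template replaces prompt_string: a list of chat messages,
    # each dict limited to the "role" and "content" keys (case-insensitive).
    prompt_template=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the ticket from {customer}."},
    ],
    # prompt_legend values may only carry "field" and "description";
    # a missing "field" is filled in from the placeholder name.
    prompt_legend={"customer": {"description": "customer name placeholder"}},
)

# read_prompt() now returns the template (a list of dicts) rather than a string;
# templates longer than MAX_PROMPT_LENGTH are spilled to a .json temp file on log.
messages = prompt.read_prompt()
```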
mlrun/artifacts/manager.py CHANGED
@@ -110,11 +110,6 @@ class ArtifactProducer:
 
 def dict_to_artifact(struct: dict) -> Artifact:
     kind = struct.get("kind", "")
-
-    # TODO: Remove once data migration v5 is obsolete
-    if mlrun.utils.is_legacy_artifact(struct):
-        return mlrun.artifacts.base.convert_legacy_artifact_to_new_format(struct)
-
     artifact_class = artifact_types[kind]
     return artifact_class.from_dict(struct)
 
mlrun/common/constants.py CHANGED
@@ -81,7 +81,6 @@ class MLRunInternalLabels:
     kind = "kind"
     component = "component"
     mlrun_type = "mlrun__type"
-    rerun_of = "rerun-of"
     original_workflow_id = "original-workflow-id"
     workflow_id = "workflow-id"
 
mlrun/common/schemas/__init__.py CHANGED
@@ -147,6 +147,7 @@ from .model_monitoring import (
     GrafanaTable,
     ModelEndpoint,
     ModelEndpointCreationStrategy,
+    ModelEndpointDriftValues,
     ModelEndpointList,
     ModelEndpointMetadata,
     ModelEndpointSchema,
mlrun/common/schemas/model_monitoring/__init__.py CHANGED
@@ -59,6 +59,7 @@ from .model_endpoints import (
     Features,
     FeatureValues,
     ModelEndpoint,
+    ModelEndpointDriftValues,
     ModelEndpointList,
     ModelEndpointMetadata,
     ModelEndpointMonitoringMetric,
mlrun/common/schemas/model_monitoring/functions.py CHANGED
@@ -64,5 +64,5 @@ class FunctionSummary(BaseModel):
             updated_time=func_dict["metadata"].get("updated"),
             status=func_dict["status"].get("state"),
             base_period=base_period,
-            stats=stats,
+            stats=stats or {},
         )
mlrun/common/schemas/model_monitoring/model_endpoints.py CHANGED
@@ -352,6 +352,16 @@ class ApplicationMetricRecord(ApplicationBaseRecord):
     type: Literal["metric"] = "metric"
 
 
+class _DriftBin(NamedTuple):
+    timestamp: datetime
+    count_suspected: int
+    count_detected: int
+
+
+class ModelEndpointDriftValues(BaseModel):
+    values: list[_DriftBin]
+
+
 def _mapping_attributes(
     model_class: type[Model],
     flattened_dictionary: dict,
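
For orientation, the new schema pairs each time bin with suspected/detected drift counts. A sketch of the payload shape, assuming pydantic's usual coercion of plain tuples into the (private) `_DriftBin` NamedTuple:

```python
from datetime import datetime, timezone

from mlrun.common.schemas import ModelEndpointDriftValues

# Each entry is one time bin: (timestamp, count_suspected, count_detected).
drift = ModelEndpointDriftValues(
    values=[
        (datetime(2025, 1, 1, tzinfo=timezone.utc), 3, 1),
        (datetime(2025, 1, 2, tzinfo=timezone.utc), 0, 0),
    ]
)
print(drift.values[0].count_suspected)  # -> 3
```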
mlrun/common/schemas/workflow.py CHANGED
@@ -49,7 +49,6 @@ class WorkflowRequest(pydantic.v1.BaseModel):
 class RerunWorkflowRequest(pydantic.v1.BaseModel):
     run_name: typing.Optional[str] = None
     run_id: typing.Optional[str] = None
-    original_workflow_id: typing.Optional[str] = None
     notifications: typing.Optional[list[Notification]] = None
     workflow_runner_node_selector: typing.Optional[dict[str, str]] = None
 
mlrun/config.py CHANGED
@@ -193,7 +193,7 @@ default_config = {
     },
     "v3io_framesd": "http://framesd:8080",
     "model_providers": {
-        "openai_default_model": "gpt-4",
+        "openai_default_model": "gpt-4o",
     },
     # default node selector to be applied to all functions - json string base64 encoded format
     "default_function_node_selector": "e30=",
mlrun/datastore/model_provider/model_provider.py CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Awaitable
-from typing import Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union
 
 import mlrun.errors
 from mlrun.datastore.remote_client import (
@@ -56,9 +56,16 @@ class ModelProvider(BaseRemoteClient):
         )
         self.default_invoke_kwargs = default_invoke_kwargs or {}
         self._client = None
-        self._default_operation = None
         self._async_client = None
-        self._default_async_operation = None
+
+    def get_client_options(self) -> dict:
+        """
+        Returns a dictionary containing credentials and configuration
+        options required for client creation.
+
+        :return: A dictionary with client-specific settings.
+        """
+        return {}
 
     def load_client(self) -> None:
         """
@@ -68,8 +75,6 @@
         Subclasses should override this method to:
         - Create and configure the provider-specific client instance.
        - Assign the client instance to self._client.
-        - Define a default operation callable (e.g., a method to invoke model completions)
-          and assign it to self._default_operation.
         """
 
         raise NotImplementedError("load_client method is not implemented")
@@ -122,39 +127,62 @@
         """
         raise NotImplementedError("invoke method is not implemented")
 
-    def customized_invoke(
+    def custom_invoke(
         self, operation: Optional[Callable[..., T]] = None, **invoke_kwargs
     ) -> Optional[T]:
-        raise NotImplementedError("customized_invoke method is not implemented")
+        """
+        Invokes a model operation from a provider (e.g., OpenAI, Hugging Face, etc.)
+        with the given keyword arguments.
+
+        Useful for dynamically calling model methods like text generation, chat
+        completions, or image generation. The operation must be a callable that
+        accepts keyword arguments.
+
+        :param operation: A callable representing the model operation (e.g., a client method).
+        :param invoke_kwargs: Keyword arguments to pass to the operation.
+        :return: The full response returned by the operation.
+        """
+        raise NotImplementedError("custom_invoke method is not implemented")
 
     @property
-    def client(self):
+    def client(self) -> Any:
         return self._client
 
     @property
-    def model(self):
+    def model(self) -> Optional[str]:
         return None
 
-    def get_invoke_kwargs(self, invoke_kwargs):
+    def get_invoke_kwargs(self, invoke_kwargs) -> dict:
        kwargs = self.default_invoke_kwargs.copy()
        kwargs.update(invoke_kwargs)
        return kwargs
 
     @property
-    def async_client(self):
+    def async_client(self) -> Any:
         if not self.support_async:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 f"{self.__class__.__name__} does not support async operations"
             )
         return self._async_client
 
-    async def async_customized_invoke(self, **kwargs):
-        raise NotImplementedError("async_customized_invoke is not implemented")
+    async def async_custom_invoke(
+        self, operation: Optional[Callable[..., Awaitable[T]]], **invoke_kwargs
+    ) -> Optional[T]:
+        """
+        Asynchronously invokes a model operation from a provider (e.g., OpenAI,
+        Hugging Face, etc.) with the given keyword arguments.
+
+        The operation must be an async callable (e.g., a method from an async client)
+        that accepts keyword arguments.
+
+        :param operation: An async callable representing the model operation
+            (e.g., an async_client method).
+        :param invoke_kwargs: Keyword arguments to pass to the operation.
+        :return: The full response returned by the awaited operation.
+        """
+        raise NotImplementedError("async_custom_invoke is not implemented")
 
     async def async_invoke(
         self,
         messages: Optional[list[dict]] = None,
         as_str: bool = False,
         **invoke_kwargs,
-    ) -> Awaitable[str]:
+    ) -> Optional[str]:
+        """Async version of `invoke`. See `invoke` for full documentation."""
         raise NotImplementedError("async_invoke is not implemented")
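
The net effect of this rework is that provider subclasses no longer stash a `_default_operation`: `load_client` only wires up the client objects, and default behavior lives inside `custom_invoke`. A hedged sketch of a conforming subclass, assuming `self.options` carries the output of `get_client_options` as in the OpenAI provider; the `AcmeProvider` name and `acme_sdk` import are hypothetical:

```python
from typing import Callable, Optional, TypeVar

from mlrun.datastore.model_provider.model_provider import ModelProvider

T = TypeVar("T")


class AcmeProvider(ModelProvider):  # hypothetical provider
    support_async = False  # no async client in this sketch

    def load_client(self) -> None:
        import acme_sdk  # hypothetical SDK

        # Only assign the client; there is no _default_operation anymore.
        self._client = acme_sdk.Client(**self.options)

    def custom_invoke(
        self, operation: Optional[Callable[..., T]] = None, **invoke_kwargs
    ) -> Optional[T]:
        # Merge default_invoke_kwargs with per-call overrides.
        kwargs = self.get_invoke_kwargs(invoke_kwargs)
        # Fall back to a provider-specific default operation when none is given.
        op = operation or self._client.generate  # hypothetical default op
        return op(**kwargs)
```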
mlrun/datastore/model_provider/openai_provider.py CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from collections.abc import Awaitable
 from typing import Callable, Optional, TypeVar, Union
 
 import mlrun
@@ -33,6 +33,8 @@ class OpenAIProvider(ModelProvider):
     operations tailored to the OpenAI API.
     """
 
+    support_async = True
+
     def __init__(
         self,
         parent,
@@ -67,7 +69,7 @@
         return endpoint, subpath
 
     @property
-    def model(self):
+    def model(self) -> Optional[str]:
         return self.endpoint
 
     def load_client(self) -> None:
@@ -76,23 +78,20 @@
 
         This method imports the `OpenAI` class from the `openai` package, instantiates
         a client with the given keyword arguments (`self.options`), and assigns it to
-        `self._client`.
-
-        It also sets the default operation to `self.client.chat.completions.create`, which is
-        typically used for invoking chat-based model completions.
+        `self._client` and `self._async_client`.
 
         Raises:
             ImportError: If the `openai` package is not installed.
         """
         try:
-            from openai import OpenAI  # noqa
+            from openai import OpenAI, AsyncOpenAI  # noqa
 
             self._client = OpenAI(**self.options)
-            self._default_operation = self.client.chat.completions.create
+            self._async_client = AsyncOpenAI(**self.options)
         except ImportError as exc:
             raise ImportError("openai package is not installed") from exc
 
-    def get_client_options(self):
+    def get_client_options(self) -> dict:
         res = dict(
             api_key=self._get_secret_or_env("OPENAI_API_KEY"),
             organization=self._get_secret_or_env("OPENAI_ORG_ID"),
@@ -103,14 +102,69 @@
         )
         return self._sanitize_options(res)
 
-    def customized_invoke(
+    def custom_invoke(
         self, operation: Optional[Callable[..., T]] = None, **invoke_kwargs
     ) -> Optional[T]:
+        """
+        OpenAI-specific implementation of `ModelProvider.custom_invoke`.
+
+        Invokes an OpenAI model operation using the sync client. For full details, see
+        `ModelProvider.custom_invoke`.
+
+        Example:
+        ```python
+        result = openai_model_provider.custom_invoke(
+            openai_model_provider.client.images.generate,
+            prompt="A futuristic cityscape at sunset",
+            n=1,
+            size="1024x1024",
+        )
+        ```
+        :param operation: Same as ModelProvider.custom_invoke.
+        :param invoke_kwargs: Same as ModelProvider.custom_invoke.
+        :return: Same as ModelProvider.custom_invoke.
+        """
         invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
         if operation:
             return operation(**invoke_kwargs, model=self.model)
         else:
-            return self._default_operation(**invoke_kwargs, model=self.model)
+            return self.client.chat.completions.create(
+                **invoke_kwargs, model=self.model
+            )
+
+    async def async_custom_invoke(
+        self,
+        operation: Optional[Callable[..., Awaitable[T]]] = None,
+        **invoke_kwargs,
+    ) -> Optional[T]:
+        """
+        OpenAI-specific implementation of `ModelProvider.async_custom_invoke`.
+
+        Invokes an OpenAI model operation using the async client. For full details, see
+        `ModelProvider.async_custom_invoke`.
+
+        Example:
+        ```python
+        result = await openai_model_provider.async_custom_invoke(
+            openai_model_provider.async_client.images.generate,
+            prompt="A futuristic cityscape at sunset",
+            n=1,
+            size="1024x1024",
+        )
+        ```
+        :param operation: Same as ModelProvider.async_custom_invoke.
+        :param invoke_kwargs: Same as ModelProvider.async_custom_invoke.
+        :return: Same as ModelProvider.async_custom_invoke.
+        """
+        invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
+        if operation:
+            return await operation(**invoke_kwargs, model=self.model)
+        else:
+            return await self.async_client.chat.completions.create(
+                **invoke_kwargs, model=self.model
+            )
 
     def invoke(
         self,
@@ -133,12 +187,39 @@
 
         :param invoke_kwargs:
             Same as ModelProvider.invoke.
+        :return: Same as ModelProvider.invoke.
 
         """
-        invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
-        response = self._default_operation(
-            model=self.endpoint, messages=messages, **invoke_kwargs
-        )
+        response = self.custom_invoke(messages=messages, **invoke_kwargs)
+        if as_str:
+            return response.choices[0].message.content
+        return response
+
+    async def async_invoke(
+        self,
+        messages: Optional[list[dict]] = None,
+        as_str: bool = False,
+        **invoke_kwargs,
+    ) -> str:
+        """
+        OpenAI-specific implementation of `ModelProvider.async_invoke`.
+        Invokes an OpenAI model operation using the async client.
+        For full details, see `ModelProvider.async_invoke`.
+
+        :param messages: Same as ModelProvider.async_invoke.
+
+        :param as_str: bool
+            If `True`, returns only the main content of the first response
+            (`response.choices[0].message.content`).
+            If `False`, returns the full awaited response object, whose type depends on
+            the specific OpenAI SDK operation used (e.g., chat completion, completion, etc.).
+
+        :param invoke_kwargs:
+            Same as ModelProvider.async_invoke.
+        :return: Same as ModelProvider.async_invoke.
+        """
+        response = await self.async_custom_invoke(messages=messages, **invoke_kwargs)
         if as_str:
             return response.choices[0].message.content
         return response
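
Taken together, `invoke` and the new `async_invoke` now share a single code path through the custom-invoke methods. A hedged usage sketch, assuming `provider` is an `OpenAIProvider` instance obtained elsewhere (its construction is outside this diff) and that an `OPENAI_API_KEY` is available:

```python
import asyncio

messages = [{"role": "user", "content": "Say hello in one word."}]

# Sync path: invoke() delegates to custom_invoke(), which defaults to
# client.chat.completions.create with model=self.model (the endpoint).
reply = provider.invoke(messages=messages, as_str=True)
print(reply)


async def main() -> None:
    # Async path: support_async=True, backed by AsyncOpenAI under the hood.
    reply = await provider.async_invoke(messages=messages, as_str=True)
    print(reply)


asyncio.run(main())
```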
mlrun/db/base.py CHANGED
@@ -638,6 +638,11 @@ class RunDBInterface(ABC):
     ):
         pass
 
+    def wait_for_background_task_to_reach_terminal_state(
+        self, name: str, project: str = ""
+    ) -> mlrun.common.schemas.BackgroundTask:
+        pass
+
     @abstractmethod
     def retry_pipeline(
         self,
@@ -1145,3 +1150,12 @@ class RunDBInterface(ABC):
     @abstractmethod
     def get_project_summary(self, project: str) -> mlrun.common.schemas.ProjectSummary:
         pass
+
+    @abstractmethod
+    def get_drift_over_time(
+        self,
+        project: str,
+        start: Optional[datetime.datetime] = None,
+        end: Optional[datetime.datetime] = None,
+    ) -> mlrun.common.schemas.model_monitoring.ModelEndpointDriftValues:
+        pass
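
A sketch of calling the new interface method from the client side, assuming `mlrun.get_run_db()` returns an implementation (e.g., `HTTPRunDB`, which `mlrun/db/httpdb.py` extends in this release) that fills it in:

```python
import datetime

import mlrun

db = mlrun.get_run_db()
# Returns ModelEndpointDriftValues: one (timestamp, count_suspected,
# count_detected) bin per interval in the requested window.
drift = db.get_drift_over_time(
    project="my-project",
    start=datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=7),
    end=datetime.datetime.now(datetime.timezone.utc),
)
for ts, suspected, detected in drift.values:
    print(ts, suspected, detected)
```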