mlrun 1.10.0rc18__py3-none-any.whl → 1.11.0rc16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.
Files changed (167)
  1. mlrun/__init__.py +24 -3
  2. mlrun/__main__.py +0 -4
  3. mlrun/artifacts/dataset.py +2 -2
  4. mlrun/artifacts/document.py +6 -1
  5. mlrun/artifacts/llm_prompt.py +21 -15
  6. mlrun/artifacts/model.py +3 -3
  7. mlrun/artifacts/plots.py +1 -1
  8. mlrun/{model_monitoring/db/tsdb/tdengine → auth}/__init__.py +2 -3
  9. mlrun/auth/nuclio.py +89 -0
  10. mlrun/auth/providers.py +429 -0
  11. mlrun/auth/utils.py +415 -0
  12. mlrun/common/constants.py +14 -0
  13. mlrun/common/model_monitoring/helpers.py +123 -0
  14. mlrun/common/runtimes/constants.py +28 -0
  15. mlrun/common/schemas/__init__.py +14 -3
  16. mlrun/common/schemas/alert.py +2 -2
  17. mlrun/common/schemas/api_gateway.py +3 -0
  18. mlrun/common/schemas/auth.py +12 -10
  19. mlrun/common/schemas/client_spec.py +4 -0
  20. mlrun/common/schemas/constants.py +25 -0
  21. mlrun/common/schemas/frontend_spec.py +1 -8
  22. mlrun/common/schemas/function.py +34 -0
  23. mlrun/common/schemas/hub.py +33 -20
  24. mlrun/common/schemas/model_monitoring/__init__.py +2 -1
  25. mlrun/common/schemas/model_monitoring/constants.py +12 -15
  26. mlrun/common/schemas/model_monitoring/functions.py +13 -4
  27. mlrun/common/schemas/model_monitoring/model_endpoints.py +11 -0
  28. mlrun/common/schemas/pipeline.py +1 -1
  29. mlrun/common/schemas/secret.py +17 -2
  30. mlrun/common/secrets.py +95 -1
  31. mlrun/common/types.py +10 -10
  32. mlrun/config.py +69 -19
  33. mlrun/data_types/infer.py +2 -2
  34. mlrun/datastore/__init__.py +12 -5
  35. mlrun/datastore/azure_blob.py +162 -47
  36. mlrun/datastore/base.py +274 -10
  37. mlrun/datastore/datastore.py +7 -2
  38. mlrun/datastore/datastore_profile.py +84 -22
  39. mlrun/datastore/model_provider/huggingface_provider.py +225 -41
  40. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  41. mlrun/datastore/model_provider/model_provider.py +206 -74
  42. mlrun/datastore/model_provider/openai_provider.py +226 -66
  43. mlrun/datastore/s3.py +39 -18
  44. mlrun/datastore/sources.py +1 -1
  45. mlrun/datastore/store_resources.py +4 -4
  46. mlrun/datastore/storeytargets.py +17 -12
  47. mlrun/datastore/targets.py +1 -1
  48. mlrun/datastore/utils.py +25 -6
  49. mlrun/datastore/v3io.py +1 -1
  50. mlrun/db/base.py +63 -32
  51. mlrun/db/httpdb.py +373 -153
  52. mlrun/db/nopdb.py +54 -21
  53. mlrun/errors.py +4 -2
  54. mlrun/execution.py +66 -25
  55. mlrun/feature_store/api.py +1 -1
  56. mlrun/feature_store/common.py +1 -1
  57. mlrun/feature_store/feature_vector_utils.py +1 -1
  58. mlrun/feature_store/steps.py +8 -6
  59. mlrun/frameworks/_common/utils.py +3 -3
  60. mlrun/frameworks/_dl_common/loggers/logger.py +1 -1
  61. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +2 -1
  62. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +1 -1
  63. mlrun/frameworks/_ml_common/utils.py +2 -1
  64. mlrun/frameworks/auto_mlrun/auto_mlrun.py +4 -3
  65. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +2 -1
  66. mlrun/frameworks/onnx/dataset.py +2 -1
  67. mlrun/frameworks/onnx/mlrun_interface.py +2 -1
  68. mlrun/frameworks/pytorch/callbacks/logging_callback.py +5 -4
  69. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +2 -1
  70. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +2 -1
  71. mlrun/frameworks/pytorch/utils.py +2 -1
  72. mlrun/frameworks/sklearn/metric.py +2 -1
  73. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +5 -4
  74. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +2 -1
  75. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +2 -1
  76. mlrun/hub/__init__.py +52 -0
  77. mlrun/hub/base.py +142 -0
  78. mlrun/hub/module.py +172 -0
  79. mlrun/hub/step.py +113 -0
  80. mlrun/k8s_utils.py +105 -16
  81. mlrun/launcher/base.py +15 -7
  82. mlrun/launcher/local.py +4 -1
  83. mlrun/model.py +14 -4
  84. mlrun/model_monitoring/__init__.py +0 -1
  85. mlrun/model_monitoring/api.py +65 -28
  86. mlrun/model_monitoring/applications/__init__.py +1 -1
  87. mlrun/model_monitoring/applications/base.py +299 -128
  88. mlrun/model_monitoring/applications/context.py +2 -4
  89. mlrun/model_monitoring/controller.py +132 -58
  90. mlrun/model_monitoring/db/_schedules.py +38 -29
  91. mlrun/model_monitoring/db/_stats.py +6 -16
  92. mlrun/model_monitoring/db/tsdb/__init__.py +9 -7
  93. mlrun/model_monitoring/db/tsdb/base.py +29 -9
  94. mlrun/model_monitoring/db/tsdb/preaggregate.py +234 -0
  95. mlrun/model_monitoring/db/tsdb/stream_graph_steps.py +63 -0
  96. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_metrics_queries.py +414 -0
  97. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_predictions_queries.py +376 -0
  98. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_results_queries.py +590 -0
  99. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connection.py +434 -0
  100. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connector.py +541 -0
  101. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_operations.py +808 -0
  102. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_schema.py +502 -0
  103. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream.py +163 -0
  104. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream_graph_steps.py +60 -0
  105. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_dataframe_processor.py +141 -0
  106. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_query_builder.py +585 -0
  107. mlrun/model_monitoring/db/tsdb/timescaledb/writer_graph_steps.py +73 -0
  108. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +20 -9
  109. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +235 -51
  110. mlrun/model_monitoring/features_drift_table.py +2 -1
  111. mlrun/model_monitoring/helpers.py +30 -6
  112. mlrun/model_monitoring/stream_processing.py +34 -28
  113. mlrun/model_monitoring/writer.py +224 -4
  114. mlrun/package/__init__.py +2 -1
  115. mlrun/platforms/__init__.py +0 -43
  116. mlrun/platforms/iguazio.py +8 -4
  117. mlrun/projects/operations.py +17 -11
  118. mlrun/projects/pipelines.py +2 -2
  119. mlrun/projects/project.py +187 -123
  120. mlrun/run.py +95 -21
  121. mlrun/runtimes/__init__.py +2 -186
  122. mlrun/runtimes/base.py +103 -25
  123. mlrun/runtimes/constants.py +225 -0
  124. mlrun/runtimes/daskjob.py +5 -2
  125. mlrun/runtimes/databricks_job/databricks_runtime.py +2 -1
  126. mlrun/runtimes/local.py +5 -2
  127. mlrun/runtimes/mounts.py +20 -2
  128. mlrun/runtimes/nuclio/__init__.py +12 -7
  129. mlrun/runtimes/nuclio/api_gateway.py +36 -6
  130. mlrun/runtimes/nuclio/application/application.py +339 -40
  131. mlrun/runtimes/nuclio/function.py +222 -72
  132. mlrun/runtimes/nuclio/serving.py +132 -42
  133. mlrun/runtimes/pod.py +213 -21
  134. mlrun/runtimes/utils.py +49 -9
  135. mlrun/secrets.py +99 -14
  136. mlrun/serving/__init__.py +2 -0
  137. mlrun/serving/remote.py +84 -11
  138. mlrun/serving/routers.py +26 -44
  139. mlrun/serving/server.py +138 -51
  140. mlrun/serving/serving_wrapper.py +6 -2
  141. mlrun/serving/states.py +997 -283
  142. mlrun/serving/steps.py +62 -0
  143. mlrun/serving/system_steps.py +149 -95
  144. mlrun/serving/v2_serving.py +9 -10
  145. mlrun/track/trackers/mlflow_tracker.py +29 -31
  146. mlrun/utils/helpers.py +292 -94
  147. mlrun/utils/http.py +9 -2
  148. mlrun/utils/notifications/notification/base.py +18 -0
  149. mlrun/utils/notifications/notification/git.py +3 -5
  150. mlrun/utils/notifications/notification/mail.py +39 -16
  151. mlrun/utils/notifications/notification/slack.py +2 -4
  152. mlrun/utils/notifications/notification/webhook.py +2 -5
  153. mlrun/utils/notifications/notification_pusher.py +3 -3
  154. mlrun/utils/version/version.json +2 -2
  155. mlrun/utils/version/version.py +3 -4
  156. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/METADATA +63 -74
  157. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/RECORD +161 -143
  158. mlrun/api/schemas/__init__.py +0 -259
  159. mlrun/db/auth_utils.py +0 -152
  160. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +0 -344
  161. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +0 -75
  162. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +0 -281
  163. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +0 -1266
  164. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/WHEEL +0 -0
  165. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/entry_points.txt +0 -0
  166. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/licenses/LICENSE +0 -0
  167. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/top_level.txt +0 -0
mlrun/datastore/model_provider/huggingface_provider.py

@@ -11,17 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import TYPE_CHECKING, Optional, TypeVar, Union
+import threading
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import mlrun
-from mlrun.datastore.model_provider.model_provider import ModelProvider
+from mlrun.datastore.model_provider.model_provider import (
+    InvokeResponseFormat,
+    ModelProvider,
+    UsageResponseKeys,
+)
 
 if TYPE_CHECKING:
     from transformers.pipelines.base import Pipeline
-
-T = TypeVar("T")
-ChatType = list[dict[str, str]]  # according to transformers.pipelines.text_generation
+    from transformers.pipelines.text_generation import ChatType
 
 
 class HuggingFaceProvider(ModelProvider):
@@ -34,8 +36,14 @@ class HuggingFaceProvider(ModelProvider):
     This class extends the ModelProvider base class and implements Hugging Face-specific
     functionality, including pipeline initialization, default text generation operations,
     and custom operations tailored to the Hugging Face Transformers pipeline API.
+
+    Note: The pipeline object will download the model (if not already cached) and load it
+    into memory for inference. Ensure you have the required CPU/GPU and memory to use this operation.
     """
 
+    # locks for threading use cases
+    _client_lock = threading.Lock()
+
     def __init__(
         self,
         parent,
@@ -60,18 +68,20 @@ class HuggingFaceProvider(ModelProvider):
         )
         self.options = self.get_client_options()
         self._expected_operation_type = None
-        self.load_client()
+        self._download_model()
 
     @staticmethod
-    def _extract_string_output(result) -> str:
+    def _extract_string_output(response: list[dict]) -> str:
         """
-        Extracts the first generated string from Hugging Face pipeline output,
-        regardless of whether it's plain text-generation or chat-style output.
+        Extracts the first generated string from Hugging Face pipeline output
         """
-        if not isinstance(result, list) or len(result) == 0:
+        if not isinstance(response, list) or len(response) == 0:
             raise ValueError("Empty or invalid pipeline output")
-
-        return result[0].get("generated_text")
+        if len(response) != 1:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "HuggingFaceProvider: extracting string from response is only supported for single-response outputs"
+            )
+        return response[0].get("generated_text")
 
     @classmethod
     def parse_endpoint_and_path(cls, endpoint, subpath) -> (str, str):
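
Annotation: the tightened `_extract_string_output` now rejects batched output. A minimal standalone sketch of the single-sequence shape it expects (illustrative values; plain `ValueError` stands in for mlrun's error types):

```python
# A text-generation pipeline returns a list with one dict per generated
# sequence, keyed by "generated_text".
response = [{"generated_text": "Paris is the capital of France."}]


def extract_string_output(response: list[dict]) -> str:
    # Mirrors the diff's logic: reject empty/invalid or multi-sequence output.
    if not isinstance(response, list) or len(response) == 0:
        raise ValueError("Empty or invalid pipeline output")
    if len(response) != 1:
        raise ValueError("only single-response outputs are supported")
    return response[0].get("generated_text")


print(extract_string_output(response))  # -> "Paris is the capital of France."
```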
@@ -81,6 +91,120 @@ class HuggingFaceProvider(ModelProvider):
             subpath = ""
         return endpoint, subpath
 
+    @property
+    def client(self) -> Any:
+        """
+        Lazily return the HuggingFace-pipeline client.
+
+        If the client has not been initialized yet, it will be created
+        by calling `load_client`.
+        """
+        self.load_client()
+        return self._client
+
+    def _download_model(self):
+        """
+        Pre-downloads model files locally to prevent race conditions in multiprocessing.
+
+        Uses snapshot_download with local_dir_use_symlinks=False to ensure proper
+        file copying for safe concurrent access across multiple processes.
+
+        :raises:
+            ImportError: If huggingface_hub package is not installed.
+        """
+        try:
+            from huggingface_hub import snapshot_download
+
+            # Download the model and tokenizer files directly to the cache.
+            snapshot_download(
+                repo_id=self.model,
+                local_dir_use_symlinks=False,
+                token=self._get_secret_or_env("HF_TOKEN") or None,
+            )
+        except ImportError as exc:
+            raise ImportError("huggingface_hub package is not installed") from exc
+
+    def _response_handler(
+        self,
+        response: Union[str, list],
+        invoke_response_format: InvokeResponseFormat = InvokeResponseFormat.FULL,
+        messages: Union[str, list[str], "ChatType", list["ChatType"]] = None,
+        **kwargs,
+    ) -> Union[str, list, dict[str, Any]]:
+        """
+        Processes and formats the raw response from the HuggingFace pipeline according to the specified format.
+
+        The response should exclude the user’s input (no repetition in the output).
+        This can be accomplished by invoking the pipeline with `return_full_text=False`.
+
+        :param response: The raw response from the HuggingFace pipeline, typically a list of dictionaries
+                         containing generated text sequences.
+        :param invoke_response_format: Determines how the response should be processed and returned. Options:
+
+            - STRING: Return only the main generated content as a string,
+              for single-answer responses.
+            - USAGE: Return a dictionary combining the string response with
+              token usage statistics:
+
+              .. code-block:: json
+
+                  {
+                      "answer": "<generated_text>",
+                      "usage": {
+                          "prompt_tokens": <int>,
+                          "completion_tokens": <int>,
+                          "total_tokens": <int>
+                      }
+                  }
+
+              Note: Token counts are estimated after answer generation and
+              may differ from the actual tokens generated by the model due to
+              internal decoding behavior and implementation details.
+
+            - FULL: Return the full raw response object.
+
+        :param messages: The original input messages used for token count estimation in USAGE mode.
+                         Can be a string, list of strings, or chat format messages.
+        :param kwargs: Additional parameters for response processing.
+
+        :return: The processed response in the format specified by `invoke_response_format`.
+                 Can be a string, dictionary, or the original response object.
+
+        :raises MLRunInvalidArgumentError: If extracting the string response fails.
+        :raises MLRunRuntimeError: If applying the chat template to the model fails during token usage calculation.
+        """
+        if InvokeResponseFormat.is_str_response(invoke_response_format.value):
+            str_response = self._extract_string_output(response)
+            if invoke_response_format == InvokeResponseFormat.STRING:
+                return str_response
+            if invoke_response_format == InvokeResponseFormat.USAGE:
+                tokenizer = self.client.tokenizer
+                if not isinstance(messages, str):
+                    try:
+                        messages = tokenizer.apply_chat_template(
+                            messages, tokenize=False, add_generation_prompt=True
+                        )
+                    except Exception as e:
+                        raise mlrun.errors.MLRunRuntimeError(
+                            f"Failed to apply chat template using the tokenizer for model '{self.model}'. "
+                            "This may indicate that the tokenizer does not support chat formatting, "
+                            "or that the input format is invalid. "
+                            f"Original error: {e}"
+                        )
+                prompt_tokens = len(tokenizer.encode(messages))
+                completion_tokens = len(tokenizer.encode(str_response))
+                total_tokens = prompt_tokens + completion_tokens
+                usage = {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": total_tokens,
+                }
+                response = {
+                    UsageResponseKeys.ANSWER: str_response,
+                    UsageResponseKeys.USAGE: usage,
+                }
+        return response
+
     def load_client(self) -> None:
         """
         Initializes the Hugging Face pipeline using the provided options.
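
Annotation: in the USAGE path above, token counts are estimated by re-tokenizing the prompt and the answer after generation; they are not reported by the model itself. A standalone sketch of that estimation, assuming a tokenizer whose chat template is available (the model name is illustrative):

```python
from transformers import AutoTokenizer

# Illustrative model choice; any chat-template-capable tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [{"role": "user", "content": "What is the capital of France?"}]
answer = "Paris."

# Flatten chat messages into the prompt string the pipeline would see.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

prompt_tokens = len(tokenizer.encode(prompt))
completion_tokens = len(tokenizer.encode(answer))
usage = {
    "prompt_tokens": prompt_tokens,
    "completion_tokens": completion_tokens,
    "total_tokens": prompt_tokens + completion_tokens,
}
print(usage)  # estimated; may differ from the model's own accounting
```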
@@ -91,15 +215,20 @@ class HuggingFaceProvider(ModelProvider):
 
         Note: Hugging Face pipelines are synchronous and do not support async invocation.
 
-        Raises:
+        :raises:
             ImportError: If the `transformers` package is not installed.
         """
+        if self._client:
+            return
         try:
             from transformers import pipeline, AutoModelForCausalLM  # noqa
             from transformers import AutoTokenizer  # noqa
             from transformers.pipelines.base import Pipeline  # noqa
 
-            self._client = pipeline(model=self.model, **self.options)
+            self.options["model_kwargs"] = self.options.get("model_kwargs", {})
+            self.options["model_kwargs"]["local_files_only"] = True
+            with self._client_lock:
+                self._client = pipeline(model=self.model, **self.options)
             self._expected_operation_type = Pipeline
         except ImportError as exc:
             raise ImportError("transformers package is not installed") from exc
@@ -117,25 +246,40 @@ class HuggingFaceProvider(ModelProvider):
 
     def custom_invoke(
         self, operation: Optional["Pipeline"] = None, **invoke_kwargs
-    ) -> Optional[T]:
+    ) -> Union[list, dict, Any]:
         """
-        HuggingFace implementation of `ModelProvider.custom_invoke`.
-        Use the default config in provider client/ user defined client:
+        Invokes a HuggingFace pipeline operation with the given keyword arguments.
+
+        This method provides flexibility to use a custom pipeline object for specific tasks
+        (e.g., image classification, sentiment analysis).
+
+        The operation must be a Pipeline object from the transformers library that accepts keyword arguments.
 
         Example:
-        ```python
+            ```python
+            from transformers import pipeline
+            from PIL import Image
+
+            # Using custom pipeline for image classification
             image = Image.open(image_path)
-        pipeline_object = pipeline("image-classification", model="microsoft/resnet-50")
+            pipeline_object = pipeline("image-classification", model="microsoft/resnet-50")
             result = hf_provider.custom_invoke(
                 pipeline_object,
                 inputs=image,
             )
-        ```
+            ```
 
+        :param operation: A Pipeline object from the transformers library.
+                          If not provided, defaults to the provider's configured pipeline.
+        :param invoke_kwargs: Keyword arguments to pass to the pipeline operation.
+                              These are merged with `default_invoke_kwargs` and may include
+                              parameters such as `inputs`, `max_length`, `temperature`, or task-specific options.
 
-        :param operation: A pipeline object
-        :param invoke_kwargs: Keyword arguments to pass to the operation.
-        :return: The full response returned by the operation.
+        :return: The full response returned by the pipeline operation.
+                 Format depends on the pipeline task (list for text generation,
+                 dict for classification, etc.).
+
+        :raises MLRunInvalidArgumentError: If the operation is not a valid Pipeline object.
 
         """
         invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
@@ -150,34 +294,74 @@ class HuggingFaceProvider(ModelProvider):
 
     def invoke(
         self,
-        messages: Union[str, list[str], ChatType, list[ChatType]] = None,
-        as_str: bool = False,
+        messages: Union[str, list[str], "ChatType", list["ChatType"]],
+        invoke_response_format: InvokeResponseFormat = InvokeResponseFormat.FULL,
         **invoke_kwargs,
-    ) -> Optional[Union[str, list, T]]:
+    ) -> Union[str, list, dict[str, Any]]:
         """
-        HuggingFace-specific implementation of `ModelProvider.invoke`.
-        Invokes a HuggingFace model operation using the synchronous client.
-        For complete usage details, refer to `ModelProvider.invoke`.
+        HuggingFace-specific implementation of model invocation using the synchronous pipeline client.
+        Invokes a HuggingFace model operation for text generation tasks.
+
+        Note: Ensure your environment has sufficient computational resources (CPU/GPU and memory) to run the model.
+
         :param messages:
-            Same as ModelProvider.invoke.
+            Input for the text generation model. Can be provided in multiple formats:
+
+            - A single string: Direct text input for generation
+            - A list of strings: Multiple text inputs for batch processing
+            - Chat format: A list of dictionaries with "role" and "content" keys:
+
+              .. code-block:: json
+
+                  [
+                      {"role": "system", "content": "You are a helpful assistant."},
+                      {"role": "user", "content": "What is the capital of France?"}
+                  ]
+
+        :param invoke_response_format: InvokeResponseFormat
+            Specifies the format of the returned response. Options:
 
-        :param as_str:
-            If `True`, returns only the main content from a single response
-            (intended for single-response use cases).
-            If `False`, returns the full response object, whose type depends on
-            the client (e.g., `pipeline`).
+            - "string": Returns only the generated text content, extracted from a single response.
+            - "usage": Combines the generated text with metadata (e.g., token usage), returning a dictionary:
+
+              .. code-block:: json
+
+                  {
+                      "answer": "<generated_text>",
+                      "usage": {
+                          "prompt_tokens": <int>,
+                          "completion_tokens": <int>,
+                          "total_tokens": <int>
+                      }
+                  }
+
+              Note: For usage mode, the model tokenizer should support apply_chat_template.
+
+            - "full": Returns the raw response object from the HuggingFace model,
+              typically a list of generated sequences (dictionaries).
+              This format does not include token usage statistics.
 
         :param invoke_kwargs:
-            Same as ModelProvider.invoke.
-        :return: Same as ModelProvider.invoke.
+            Additional keyword arguments passed to the HuggingFace pipeline.
+
+        :return:
+            A string, dictionary, or list of model outputs, depending on `invoke_response_format`.
+
+        :raises MLRunInvalidArgumentError:
+            If the pipeline task is not "text-generation" or if the response contains multiple outputs when extracting
+            string content.
+        :raises MLRunRuntimeError:
+            If using "usage" response mode and the model tokenizer does not support chat template formatting.
         """
         if self.client.task != "text-generation":
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "HuggingFaceProvider.invoke supports text-generation task only"
             )
-        if as_str:
+        if InvokeResponseFormat.is_str_response(invoke_response_format.value):
            invoke_kwargs["return_full_text"] = False
         response = self.custom_invoke(text_inputs=messages, **invoke_kwargs)
-        if as_str:
-            return self._extract_string_output(response)
+        response = self._response_handler(
+            messages=messages,
+            response=response,
+            invoke_response_format=invoke_response_format,
+        )
         return response
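
Annotation: callers now pick the output shape through `InvokeResponseFormat` instead of the removed `as_str` flag. A usage sketch, assuming `hf_provider` is an already-initialized HuggingFaceProvider:

```python
from mlrun.datastore.model_provider.model_provider import InvokeResponseFormat

messages = [{"role": "user", "content": "What is the capital of France?"}]

# Full raw pipeline output (the default): a list of generated sequences.
full = hf_provider.invoke(messages)

# Just the generated text of a single response.
text = hf_provider.invoke(
    messages, invoke_response_format=InvokeResponseFormat.STRING
)

# Text plus estimated token accounting: {"answer": ..., "usage": {...}}.
with_usage = hf_provider.invoke(
    messages, invoke_response_format=InvokeResponseFormat.USAGE
)
```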
mlrun/datastore/model_provider/mock_model_provider.py (new file)

@@ -0,0 +1,87 @@
+# Copyright 2023 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Union
+
+import mlrun
+from mlrun.datastore.model_provider.model_provider import (
+    InvokeResponseFormat,
+    ModelProvider,
+    UsageResponseKeys,
+)
+
+
+class MockModelProvider(ModelProvider):
+    support_async = False
+
+    def __init__(
+        self,
+        parent,
+        kind,
+        name,
+        endpoint="",
+        secrets: Optional[dict] = None,
+        default_invoke_kwargs: Optional[dict] = None,
+    ):
+        super().__init__(
+            parent=parent, name=name, kind=kind, endpoint=endpoint, secrets=secrets
+        )
+        self.default_invoke_kwargs = default_invoke_kwargs or {}
+        self._client = None
+        self._async_client = None
+
+    @staticmethod
+    def _extract_string_output(response: Any) -> str:
+        """
+        Extracts string response from response object
+        """
+        pass
+
+    def load_client(self) -> None:
+        """
+        Initializes the SDK client for the model provider with the given keyword arguments
+        and assigns it to an instance attribute (e.g., self._client).
+
+        Subclasses should override this method to:
+        - Create and configure the provider-specific client instance.
+        - Assign the client instance to self._client.
+        """
+
+        pass
+
+    def invoke(
+        self,
+        messages: Union[list[dict], Any],
+        invoke_response_format: InvokeResponseFormat = InvokeResponseFormat.FULL,
+        **invoke_kwargs,
+    ) -> Union[str, dict[str, Any], Any]:
+        if invoke_response_format == InvokeResponseFormat.STRING:
+            return (
+                "You are using a mock model provider, no actual inference is performed."
+            )
+        elif invoke_response_format == InvokeResponseFormat.FULL:
+            return {
+                UsageResponseKeys.USAGE: {"prompt_tokens": 0, "completion_tokens": 0},
+                UsageResponseKeys.ANSWER: "You are using a mock model provider, no actual inference is performed.",
+                "extra": {},
+            }
+        elif invoke_response_format == InvokeResponseFormat.USAGE:
+            return {
+                UsageResponseKeys.ANSWER: "You are using a mock model provider, no actual inference is performed.",
+                UsageResponseKeys.USAGE: {"prompt_tokens": 0, "completion_tokens": 0},
+            }
+        else:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                f"Unsupported invoke response format: {invoke_response_format}"
+            )
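
Annotation: the mock provider answers every format with canned values, which makes it useful for exercising calling code without downloading a model. A test-style sketch, assuming a constructed instance named `mock_provider` and string-valued `UsageResponseKeys` members:

```python
from mlrun.datastore.model_provider.model_provider import (
    InvokeResponseFormat,
    UsageResponseKeys,
)

# Assumes mock_provider = MockModelProvider(parent=..., kind=..., name=...)
# was built elsewhere; construction arguments mirror the other providers.
answer = mock_provider.invoke(
    [{"role": "user", "content": "ping"}],
    invoke_response_format=InvokeResponseFormat.STRING,
)
assert answer == (
    "You are using a mock model provider, no actual inference is performed."
)

usage = mock_provider.invoke([], invoke_response_format=InvokeResponseFormat.USAGE)
assert usage[UsageResponseKeys.USAGE] == {"prompt_tokens": 0, "completion_tokens": 0}
```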