mlrun 1.8.0rc10__py3-none-any.whl → 1.8.0rc12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.



mlrun/common/constants.py CHANGED
@@ -25,6 +25,7 @@ MYSQL_MEDIUMBLOB_SIZE_BYTES = 16 * 1024 * 1024
  MLRUN_LABEL_PREFIX = "mlrun/"
  DASK_LABEL_PREFIX = "dask.org/"
  NUCLIO_LABEL_PREFIX = "nuclio.io/"
+ RESERVED_TAG_NAME_LATEST = "latest"


  class MLRunInternalLabels:
@@ -146,8 +146,10 @@ from .model_monitoring import (
      GrafanaTable,
      GrafanaTimeSeriesTarget,
      ModelEndpoint,
+     ModelEndpointCreationStrategy,
      ModelEndpointList,
      ModelEndpointMetadata,
+     ModelEndpointSchema,
      ModelEndpointSpec,
      ModelEndpointStatus,
      ModelMonitoringMode,
@@ -26,6 +26,7 @@ from .constants import (
      FileTargetKind,
      FunctionURI,
      MetricData,
+     ModelEndpointCreationStrategy,
      ModelEndpointMonitoringMetricType,
      ModelEndpointSchema,
      ModelEndpointTarget,
@@ -71,6 +71,12 @@ class ModelEndpointSchema(MonitoringStrEnum):
      DRIFT_MEASURES = "drift_measures"


+ class ModelEndpointCreationStrategy(MonitoringStrEnum):
+     INPLACE = "inplace"
+     ARCHIVE = "archive"
+     OVERWRITE = "overwrite"
+
+
  class EventFieldType:
      FUNCTION_URI = "function_uri"
      FUNCTION = "function"
@@ -117,6 +117,10 @@ class ModelEndpointMetadata(ObjectMetadata, ModelEndpointParser):
      endpoint_type: EndpointType = EndpointType.NODE_EP
      uid: Optional[constr(regex=MODEL_ENDPOINT_ID_PATTERN)]

+     @classmethod
+     def mutable_fields(cls):
+         return ["labels"]
+

  class ModelEndpointSpec(ObjectSpec, ModelEndpointParser):
      model_uid: Optional[str] = ""
@@ -136,6 +140,21 @@ class ModelEndpointSpec(ObjectSpec, ModelEndpointParser):
      children_uids: Optional[list[str]] = []
      monitoring_feature_set_uri: Optional[str] = ""

+     @classmethod
+     def mutable_fields(cls):
+         return [
+             "model_uid",
+             "model_name",
+             "model_db_key",
+             "model_tag",
+             "model_class",
+             "function_uid",
+             "feature_names",
+             "label_names",
+             "children",
+             "children_uids",
+         ]
+

  class ModelEndpointStatus(ObjectStatus, ModelEndpointParser):
      state: Optional[str] = "unknown"  # will be updated according to the function state
@@ -152,6 +171,14 @@ class ModelEndpointStatus(ObjectStatus, ModelEndpointParser):
      drift_measures: Optional[dict] = {}
      drift_measures_timestamp: Optional[datetime] = None

+     @classmethod
+     def mutable_fields(cls):
+         return [
+             "monitoring_mode",
+             "first_request",
+             "last_request",
+         ]
+

  class ModelEndpoint(BaseModel):
      kind: ObjectKind = Field(ObjectKind.model_endpoint, const=True)
@@ -159,6 +186,14 @@ class ModelEndpoint(BaseModel):
      spec: ModelEndpointSpec
      status: ModelEndpointStatus

+     @classmethod
+     def mutable_fields(cls):
+         return (
+             ModelEndpointMetadata.mutable_fields()
+             + ModelEndpointSpec.mutable_fields()
+             + ModelEndpointStatus.mutable_fields()
+         )
+

      def flat_dict(self) -> dict[str, Any]:
          """Generate a flattened `ModelEndpoint` dictionary. The flattened dictionary result is important for storing
          the model endpoint object in the database.
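
The `ModelEndpointCreationStrategy` enum and the `mutable_fields()` class methods above are plain schema additions. A minimal usage sketch (assumes mlrun 1.8.0rc12 is installed; the filtering shown here is illustrative, not necessarily how the server applies `mutable_fields()`)::

    from mlrun.common.schemas import ModelEndpoint, ModelEndpointCreationStrategy

    # The three strategies introduced in this release.
    strategy = ModelEndpointCreationStrategy.ARCHIVE

    # mutable_fields() aggregates the metadata/spec/status lists, so a caller
    # could drop keys that the endpoint record does not allow to change.
    requested_update = {"labels": {"team": "fraud"}, "uid": "abc123", "model_tag": "v2"}
    allowed_update = {
        key: value
        for key, value in requested_update.items()
        if key in ModelEndpoint.mutable_fields()
    }
    # keeps "labels" and "model_tag", drops "uid"
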
mlrun/db/base.py CHANGED
@@ -666,6 +666,7 @@ class RunDBInterface(ABC):
      def create_model_endpoint(
          self,
          model_endpoint: mlrun.common.schemas.ModelEndpoint,
+         creation_strategy: mlrun.common.schemas.ModelEndpointCreationStrategy = "inplace",
      ) -> mlrun.common.schemas.ModelEndpoint:
          pass

@@ -688,6 +689,7 @@ class RunDBInterface(ABC):
          function_name: Optional[str] = None,
          function_tag: Optional[str] = None,
          model_name: Optional[str] = None,
+         model_tag: Optional[str] = None,
          labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
          start: Optional[datetime.datetime] = None,
          end: Optional[datetime.datetime] = None,
mlrun/db/httpdb.py CHANGED
@@ -3582,11 +3582,17 @@ class HTTPRunDB(RunDBInterface):
      def create_model_endpoint(
          self,
          model_endpoint: mlrun.common.schemas.ModelEndpoint,
+         creation_strategy: mlrun.common.schemas.ModelEndpointCreationStrategy = "inplace",
      ) -> mlrun.common.schemas.ModelEndpoint:
          """
          Creates a DB record with the given model_endpoint record.

          :param model_endpoint: An object representing the model endpoint.
+         :param creation_strategy: model endpoint creation strategy:
+             * overwrite - Create a new model endpoint and delete the last old one if it exists.
+             * inplace - Use the existing model endpoint if it already exists (default).
+             * archive - Preserve the old model endpoint and create a new one,
+               tagging it as the latest.

          :return: The created model endpoint object.
          """
@@ -3596,6 +3602,9 @@ class HTTPRunDB(RunDBInterface):
              method=mlrun.common.types.HTTPMethod.POST,
              path=path,
              body=model_endpoint.json(),
+             params={
+                 "creation_strategy": creation_strategy,
+             },
          )
          return mlrun.common.schemas.ModelEndpoint(**response.json())

@@ -3637,6 +3646,7 @@ class HTTPRunDB(RunDBInterface):
          function_name: Optional[str] = None,
          function_tag: Optional[str] = None,
          model_name: Optional[str] = None,
+         model_tag: Optional[str] = None,
          labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
          start: Optional[datetime] = None,
          end: Optional[datetime] = None,
@@ -3653,6 +3663,7 @@ class HTTPRunDB(RunDBInterface):
          :param function_name: The name of the function
          :param function_tag: The tag of the function
          :param model_name: The name of the model
+         :param model_tag: The tag of the model
          :param labels: A list of labels to filter by. (see mlrun.common.schemas.LabelsModel)
          :param start: The start time to filter by. Corresponding to the `created` field.
          :param end: The end time to filter by. Corresponding to the `created` field.
@@ -3671,6 +3682,7 @@ class HTTPRunDB(RunDBInterface):
              params={
                  "name": name,
                  "model_name": model_name,
+                 "model_tag": model_tag,
                  "function_name": function_name,
                  "function_tag": function_tag,
                  "label": labels,
mlrun/db/nopdb.py CHANGED
@@ -575,6 +575,7 @@ class NopDB(RunDBInterface):
      def create_model_endpoint(
          self,
          model_endpoint: mlrun.common.schemas.ModelEndpoint,
+         creation_strategy: mlrun.common.schemas.ModelEndpointCreationStrategy = "inplace",
      ) -> mlrun.common.schemas.ModelEndpoint:
          pass

@@ -595,6 +596,7 @@ class NopDB(RunDBInterface):
          function_name: Optional[str] = None,
          function_tag: Optional[str] = None,
          model_name: Optional[str] = None,
+         model_tag: Optional[str] = None,
          labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
          start: Optional[datetime.datetime] = None,
          end: Optional[datetime.datetime] = None,
@@ -671,7 +671,7 @@ class SetEventMetadata(MapClass):

          self._tagging_funcs = []

-     def post_init(self, mode="sync"):
+     def post_init(self, mode="sync", **kwargs):
          def add_metadata(name, path, operator=str):
              def _add_meta(event):
                  value = get_in(event.body, path)
@@ -54,9 +54,10 @@ def get_or_create_model_endpoint(
      model_endpoint_name: str = "",
      endpoint_id: str = "",
      function_name: str = "",
+     function_tag: str = "latest",
      context: typing.Optional["mlrun.MLClientCtx"] = None,
      sample_set_statistics: typing.Optional[dict[str, typing.Any]] = None,
-     monitoring_mode: mm_constants.ModelMonitoringMode = mm_constants.ModelMonitoringMode.disabled,
+     monitoring_mode: mm_constants.ModelMonitoringMode = mm_constants.ModelMonitoringMode.enabled,
      db_session=None,
  ) -> ModelEndpoint:
      """
@@ -70,8 +71,8 @@ def get_or_create_model_endpoint(
          under this endpoint (applicable only to new endpoint_id).
      :param endpoint_id: Model endpoint unique ID. If not exist in DB, will generate a new record based
          on the provided `endpoint_id`.
-     :param function_name: If a new model endpoint is created, use this function name for generating the
-         function URI (applicable only to new endpoint_id).
+     :param function_name: If a new model endpoint is created, use this function name.
+     :param function_tag: If a new model endpoint is created, use this function tag.
      :param context: MLRun context. If `function_name` not provided, use the context to generate the
          full function hash.
      :param sample_set_statistics: Dictionary of sample set statistics that will be used as a reference data for
@@ -86,28 +87,32 @@ def get_or_create_model_endpoint(
      if not db_session:
          # Generate a runtime database
          db_session = mlrun.get_run_db()
+     model_endpoint = None
      try:
-         model_endpoint = db_session.get_model_endpoint(
-             project=project,
-             name=model_endpoint_name,
-             endpoint_id=endpoint_id,
-             function_name=function_name,
-         )
-         # If other fields provided, validate that they are correspond to the existing model endpoint data
-         _model_endpoint_validations(
-             model_endpoint=model_endpoint,
-             model_path=model_path,
-             sample_set_statistics=sample_set_statistics,
-         )
+         if endpoint_id:
+             model_endpoint = db_session.get_model_endpoint(
+                 project=project,
+                 name=model_endpoint_name,
+                 endpoint_id=endpoint_id,
+             )
+             # If other fields provided, validate that they are correspond to the existing model endpoint data
+             _model_endpoint_validations(
+                 model_endpoint=model_endpoint,
+                 model_path=model_path,
+                 sample_set_statistics=sample_set_statistics,
+             )

      except mlrun.errors.MLRunNotFoundError:
          # Create a new model endpoint with the provided details
+         pass
+     if not model_endpoint:
          model_endpoint = _generate_model_endpoint(
              project=project,
              db_session=db_session,
              model_path=model_path,
              model_endpoint_name=model_endpoint_name,
              function_name=function_name,
+             function_tag=function_tag,
              context=context,
              sample_set_statistics=sample_set_statistics,
              monitoring_mode=monitoring_mode,
@@ -333,9 +338,10 @@ def _generate_model_endpoint(
      model_path: str,
      model_endpoint_name: str,
      function_name: str,
+     function_tag: str,
      context: "mlrun.MLClientCtx",
      sample_set_statistics: dict[str, typing.Any],
-     monitoring_mode: mm_constants.ModelMonitoringMode = mm_constants.ModelMonitoringMode.disabled,
+     monitoring_mode: mm_constants.ModelMonitoringMode = mm_constants.ModelMonitoringMode.enabled,
  ) -> ModelEndpoint:
      """
      Write a new model endpoint record.
@@ -345,8 +351,8 @@ def _generate_model_endpoint(
      :param db_session: A session that manages the current dialog with the database.
      :param model_path: The model Store path.
      :param model_endpoint_name: Model endpoint name will be presented under the new model endpoint.
-     :param function_name: If a new model endpoint is created, use this function name for generating the
-         function URI.
+     :param function_name: If a new model endpoint is created, use this function name.
+     :param function_tag: If a new model endpoint is created, use this function tag.
      :param context: MLRun context. If function_name not provided, use the context to generate the
          full function hash.
      :param sample_set_statistics: Dictionary of sample set statistics that will be used as a reference data for
@@ -374,7 +380,8 @@ def _generate_model_endpoint(
              endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.BATCH_EP,
          ),
          spec=mlrun.common.schemas.ModelEndpointSpec(
-             function_name=function_name,
+             function_name=function_name or "function",
+             function_tag=function_tag or "latest",
              model_name=model_obj.metadata.key if model_obj else None,
              model_uid=model_obj.metadata.uid if model_obj else None,
              model_tag=model_obj.metadata.tag if model_obj else None,
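
The reworked lookup above only queries the DB when an `endpoint_id` is supplied, otherwise it falls through to creating a new record, now with monitoring enabled by default and an explicit function tag. A sketch of a direct call (the import path `mlrun.model_monitoring.api` and the store URI are assumptions, not shown in this diff)::

    import mlrun
    from mlrun.model_monitoring.api import get_or_create_model_endpoint

    project = mlrun.get_or_create_project("my-project")

    # With no endpoint_id the helper skips the DB lookup and creates a new
    # endpoint; monitoring_mode now defaults to "enabled".
    model_endpoint = get_or_create_model_endpoint(
        project=project.name,
        model_path="store://models/my-project/churn-model:latest",  # placeholder store URI
        model_endpoint_name="churn-endpoint",
        function_name="batch-infer",
        function_tag="latest",
    )
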
@@ -148,6 +148,44 @@ class ModelMonitoringApplicationBase(MonitoringApplicationToDict, ABC):
          )
          return start, end

+     @classmethod
+     def deploy(
+         cls,
+         func_name: str,
+         func_path: Optional[str] = None,
+         image: Optional[str] = None,
+         handler: Optional[str] = None,
+         with_repo: Optional[bool] = False,
+         tag: Optional[str] = None,
+         requirements: Optional[Union[str, list[str]]] = None,
+         requirements_file: str = "",
+         **application_kwargs,
+     ) -> None:
+         """
+         Set the application to the current project and deploy it as a Nuclio serving function.
+         Required for your model monitoring application to work as a part of the model monitoring framework.
+
+         :param func_name: The name of the function.
+         :param func_path: The path of the function, :code:`None` refers to the current Jupyter notebook.
+
+         For the other arguments, refer to
+         :py:meth:`~mlrun.projects.MlrunProject.set_model_monitoring_function`.
+         """
+         project = cast("mlrun.MlrunProject", mlrun.get_current_project())
+         function = project.set_model_monitoring_function(
+             name=func_name,
+             func=func_path,
+             application_class=cls.__name__,
+             handler=handler,
+             image=image,
+             with_repo=with_repo,
+             requirements=requirements,
+             requirements_file=requirements_file,
+             tag=tag,
+             **application_kwargs,
+         )
+         function.deploy()
+
      @classmethod
      def evaluate(
          cls,
@@ -175,10 +213,10 @@ class ModelMonitoringApplicationBase(MonitoringApplicationToDict, ABC):
          :param func_name: The name of the function. If not passed, the class name is used.
          :param tag: An optional tag for the function.
          :param run_local: Whether to run the function locally or remotely.
-         :param sample_df: Optional - pandas data-frame as the current dataset.
-             When set, it replaces the data read from the model endpoint's offline source.
-         :param feature_stats: Optional - statistics dictionary of the reference data.
-             When set, it overrides the model endpoint's feature stats.
+         :param sample_data: Optional - pandas data-frame as the current dataset.
+             When set, it replaces the data read from the model endpoint's offline source.
+         :param reference_data: Optional - pandas data-frame of the reference dataset.
+             When set, its statistics override the model endpoint's feature statistics.
          :param image: Docker image to run the job on.
          :param with_repo: Whether to clone the current repo to the build source.
          :param requirements: List of Python requirements to be installed in the image.
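
The new `deploy()` classmethod collapses `set_model_monitoring_function()` plus `function.deploy()` into one call. A sketch of how an application author might use it (the `do_tracking` hook name and the import path are assumptions based on recent MLRun releases, not shown in this diff)::

    import mlrun
    from mlrun.model_monitoring.applications import ModelMonitoringApplicationBase


    class MyMonitoringApp(ModelMonitoringApplicationBase):
        def do_tracking(self, monitoring_context):  # hook name assumed; check the class docs for your version
            # compute and return drift/performance results here
            ...


    # Registers the class on the current project and deploys it as a Nuclio
    # serving function, replacing the manual two-step flow.
    mlrun.get_or_create_project("my-project")
    MyMonitoringApp.deploy(func_name="my-monitoring-app", image="mlrun/mlrun")
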
mlrun/projects/project.py CHANGED
@@ -2117,13 +2117,13 @@ class MlrunProject(ModelObj):

      def set_model_monitoring_function(
          self,
+         name: str,
          func: typing.Union[str, mlrun.runtimes.RemoteRuntime, None] = None,
          application_class: typing.Union[
              str, mm_app.ModelMonitoringApplicationBase, None
          ] = None,
-         name: Optional[str] = None,
          image: Optional[str] = None,
-         handler=None,
+         handler: Optional[str] = None,
          with_repo: Optional[bool] = None,
          tag: Optional[str] = None,
          requirements: Optional[typing.Union[str, list[str]]] = None,
@@ -2135,7 +2135,7 @@ class MlrunProject(ModelObj):
          Note: to deploy the function after linking it to the project,
          call `fn.deploy()` where `fn` is the object returned by this method.

-         examples::
+         Example::

              project.set_model_monitoring_function(
                  name="myApp", application_class="MyApp", image="mlrun/mlrun"
@@ -2144,8 +2144,7 @@ class MlrunProject(ModelObj):
          :param func: Remote function object or spec/code URL. :code:`None` refers to the current
              notebook.
          :param name: Name of the function (under the project), can be specified with a tag to support
-             versions (e.g. myfunc:v1)
-             Default: job
+             versions (e.g. myfunc:v1).
          :param image: Docker image to be used, can also be specified in
              the function object/yaml
          :param handler: Default function handler to invoke (can only be set with .py/.ipynb files)
@@ -2183,12 +2182,13 @@ class MlrunProject(ModelObj):

      def create_model_monitoring_function(
          self,
+         name: str,
          func: Optional[str] = None,
          application_class: typing.Union[
              str,
              mm_app.ModelMonitoringApplicationBase,
+             None,
          ] = None,
-         name: Optional[str] = None,
          image: Optional[str] = None,
          handler: Optional[str] = None,
          with_repo: Optional[bool] = None,
@@ -2200,16 +2200,15 @@ class MlrunProject(ModelObj):
          """
          Create a monitoring function object without setting it to the project

-         examples::
+         Example::

              project.create_model_monitoring_function(
-                 application_class_name="MyApp", image="mlrun/mlrun", name="myApp"
+                 name="myApp", application_class="MyApp", image="mlrun/mlrun"
              )

          :param func: The function's code URL. :code:`None` refers to the current notebook.
          :param name: Name of the function, can be specified with a tag to support
-             versions (e.g. myfunc:v1)
-             Default: job
+             versions (e.g. myfunc:v1).
          :param image: Docker image to be used, can also be specified in
              the function object/yaml
          :param handler: Default function handler to invoke (can only be set with .py/.ipynb files)
@@ -3551,6 +3550,7 @@ class MlrunProject(ModelObj):
          self,
          name: Optional[str] = None,
          model_name: Optional[str] = None,
+         model_tag: Optional[str] = None,
          function_name: Optional[str] = None,
          function_tag: Optional[str] = None,
          labels: Optional[list[str]] = None,
@@ -3565,12 +3565,13 @@ class MlrunProject(ModelObj):
          model endpoint. This functions supports filtering by the following parameters:
          1) name
          2) model_name
-         3) function_name
-         4) function_tag
-         5) labels
-         6) top level
-         7) uids
-         8) start and end time, corresponding to the `created` field.
+         3) model_tag
+         4) function_name
+         5) function_tag
+         6) labels
+         7) top level
+         8) uids
+         9) start and end time, corresponding to the `created` field.
          By default, when no filters are applied, all available endpoints for the given project will be listed.

          In addition, this functions provides a facade for listing endpoint related metrics. This facade is time-based
@@ -3599,6 +3600,7 @@ class MlrunProject(ModelObj):
              project=self.name,
              name=name,
              model_name=model_name,
+             model_tag=model_tag,
              function_name=function_name,
              function_tag=function_tag,
              labels=labels,
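
With the filter list extended above, the project-level facade can now narrow results by model tag as well. A brief sketch (project and model names are placeholders)::

    import mlrun

    project = mlrun.get_or_create_project("my-project")

    # model_tag joins the existing filters (name, model_name, function_name, ...).
    endpoints = project.list_model_endpoints(model_name="churn-model", model_tag="v2")
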
@@ -22,7 +22,7 @@ import nuclio
  from nuclio import KafkaTrigger

  import mlrun
- import mlrun.common.schemas
+ import mlrun.common.schemas as schemas
  from mlrun.datastore import get_kafka_brokers_from_dict, parse_kafka_url
  from mlrun.model import ObjectList
  from mlrun.runtimes.function_reference import FunctionReference
@@ -362,6 +362,9 @@ class ServingRuntime(RemoteRuntime):
          handler: Optional[str] = None,
          router_step: Optional[str] = None,
          child_function: Optional[str] = None,
+         creation_strategy: Optional[
+             schemas.ModelEndpointCreationStrategy
+         ] = schemas.ModelEndpointCreationStrategy.INPLACE,
          **class_args,
      ):
          """add ml model and/or route to the function.
@@ -384,6 +387,11 @@ class ServingRuntime(RemoteRuntime):
          :param router_step: router step name (to determine which router we add the model to in graphs
              with multiple router steps)
          :param child_function: child function name, when the model runs in a child function
+         :param creation_strategy: model endpoint creation strategy:
+             * overwrite - Create a new model endpoint and delete the last old one if it exists.
+             * inplace - Use the existing model endpoint if it already exists (default).
+             * archive - Preserve the old model endpoint and create a new one,
+               tagging it as the latest.
          :param class_args: extra kwargs to pass to the model serving class __init__
              (can be read in the model using .get_param(key) method)
          """
@@ -419,7 +427,12 @@ class ServingRuntime(RemoteRuntime):
          if class_name and hasattr(class_name, "to_dict"):
              if model_path:
                  class_name.model_path = model_path
-             key, state = params_to_step(class_name, key)
+             key, state = params_to_step(
+                 class_name,
+                 key,
+                 model_endpoint_creation_strategy=creation_strategy,
+                 endpoint_type=schemas.EndpointType.LEAF_EP,
+             )
          else:
              class_name = class_name or self.spec.default_class
              if class_name and not isinstance(class_name, str):
@@ -432,12 +445,22 @@ class ServingRuntime(RemoteRuntime):
              model_path = str(model_path)

          if model_url:
-             state = new_remote_endpoint(model_url, **class_args)
+             state = new_remote_endpoint(
+                 model_url,
+                 creation_strategy=creation_strategy,
+                 endpoint_type=schemas.EndpointType.LEAF_EP,
+                 **class_args,
+             )
          else:
              class_args = deepcopy(class_args)
              class_args["model_path"] = model_path
              state = TaskStep(
-                 class_name, class_args, handler=handler, function=child_function
+                 class_name,
+                 class_args,
+                 handler=handler,
+                 function=child_function,
+                 model_endpoint_creation_strategy=creation_strategy,
+                 endpoint_type=schemas.EndpointType.LEAF_EP,
              )

          return graph.add_route(key, state)
@@ -581,7 +604,7 @@ class ServingRuntime(RemoteRuntime):
          project="",
          tag="",
          verbose=False,
-         auth_info: mlrun.common.schemas.AuthInfo = None,
+         auth_info: schemas.AuthInfo = None,
          builder_env: Optional[dict] = None,
          force_build: bool = False,
      ):
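
On the serving side, `add_model()` now threads the creation strategy (and a leaf endpoint type) down to the step that backs each model route. A hedged sketch (the store URI and names are placeholders)::

    import mlrun
    from mlrun.common.schemas import ModelEndpointCreationStrategy

    serving_fn = mlrun.new_function("serving", kind="serving", image="mlrun/mlrun")
    serving_fn.add_model(
        "churn-model",
        model_path="store://models/my-project/churn-model:latest",  # placeholder store URI
        class_name="mlrun.serving.V2ModelServer",
        # Keep the previous endpoint record and tag the new one as latest.
        creation_strategy=ModelEndpointCreationStrategy.ARCHIVE,
    )
    serving_fn.deploy()
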
mlrun/serving/__init__.py CHANGED
@@ -23,6 +23,10 @@ __all__ = [
      "QueueStep",
      "ErrorStep",
      "MonitoringApplicationStep",
+     "ModelRunnerStep",
+     "ModelRunner",
+     "Model",
+     "ModelSelector",
  ]

  from .routers import ModelRouter, VotingEnsemble  # noqa
@@ -33,6 +37,10 @@ from .states import (
      RouterStep,
      TaskStep,
      MonitoringApplicationStep,
+     ModelRunnerStep,
+     ModelRunner,
+     Model,
+     ModelSelector,
  )  # noqa
  from .v1_serving import MLModelServer, new_v1_model_server  # noqa
  from .v2_serving import V2ModelServer  # noqa
mlrun/serving/merger.py CHANGED
@@ -74,7 +74,7 @@ class Merge(storey.Flow):
          self._queue_len = max_behind or 64  # default queue is 64 entries
          self._keys_queue = []

-     def post_init(self, mode="sync"):
+     def post_init(self, mode="sync", **kwargs):
          # auto detect number of uplinks or use user specified value
          self._uplinks = self.expected_num_events or (
              len(self._graph_step.after) if self._graph_step else 0
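
Several step classes in this release (`SetEventMetadata`, `Merge`, `RemoteStep`, `BatchHttpRequests`) widen `post_init(self, mode="sync")` to accept `**kwargs`, so the graph server can pass extra keyword arguments without breaking them. A custom step can follow the same convention (class and attribute names here are illustrative)::

    import storey


    class MyCustomStep(storey.MapClass):
        def post_init(self, mode="sync", **kwargs):
            # Ignore unknown keyword arguments so the step keeps working if the
            # graph server starts passing additional parameters to post_init.
            self._ready = True

        def do(self, event):
            return event
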
mlrun/serving/remote.py CHANGED
@@ -14,6 +14,7 @@
  #
  import asyncio
  import json
+ from copy import copy
  from typing import Optional

  import aiohttp
@@ -53,6 +54,7 @@ class RemoteStep(storey.SendToHttp):
          retries=None,
          backoff_factor=None,
          timeout=None,
+         headers_expression: Optional[str] = None,
          **kwargs,
      ):
          """class for calling remote endpoints
@@ -86,6 +88,7 @@ class RemoteStep(storey.SendToHttp):
          :param retries: number of retries (in exponential backoff)
          :param backoff_factor: A backoff factor in seconds to apply between attempts after the second try
          :param timeout: How long to wait for the server to send data before giving up, float in seconds
+         :param headers_expression: an expression for getting the request headers from the event, e.g. "event['headers']"
          """
          # init retry args for storey
          retries = default_retries if retries is None else retries
@@ -102,6 +105,7 @@ class RemoteStep(storey.SendToHttp):
          self.url = url
          self.url_expression = url_expression
          self.body_expression = body_expression
+         self.headers_expression = headers_expression
          self.headers = headers
          self.method = method
          self.return_json = return_json
@@ -114,8 +118,9 @@ class RemoteStep(storey.SendToHttp):
          self._session = None
          self._url_function_handler = None
          self._body_function_handler = None
+         self._headers_function_handler = None

-     def post_init(self, mode="sync"):
+     def post_init(self, mode="sync", **kwargs):
          self._endpoint = self.url
          if self.url and self.context:
              self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
@@ -131,6 +136,12 @@ class RemoteStep(storey.SendToHttp):
                  {"endpoint": self._endpoint, "context": self.context},
                  {},
              )
+             if self.headers_expression:
+                 self._headers_function_handler = eval(
+                     "lambda event: " + self.headers_expression,
+                     {"context": self.context},
+                     {},
+                 )
          elif self.subpath:
              self._append_event_path = self.subpath == "$path"
              if not self._append_event_path:
@@ -205,7 +216,10 @@ class RemoteStep(storey.SendToHttp):

      def _generate_request(self, event, body):
          method = self.method or event.method or "POST"
-         headers = self.headers or {}
+         if self._headers_function_handler:
+             headers = self._headers_function_handler(body)
+         else:
+             headers = copy(self.headers) or {}

          if self._url_function_handler:
              url = self._url_function_handler(body)
@@ -216,10 +230,8 @@ class RemoteStep(storey.SendToHttp):
              url = url + "/" + striped_path
              if striped_path:
                  headers[event_path_key] = event.path
-
          if event.id:
              headers[event_id_key] = event.id
-
          if method == "GET":
              body = None
          elif body is not None and not isinstance(body, (str, bytes)):
@@ -334,7 +346,7 @@ class BatchHttpRequests(_ConcurrentJobExecution):
      async def _cleanup(self):
          await self._client_session.close()

-     def post_init(self, mode="sync"):
+     def post_init(self, mode="sync", **kwargs):
          self._endpoint = self.url
          if self.url and self.context:
              self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
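
The new `headers_expression` mirrors `url_expression` and `body_expression`: the string is compiled into `lambda event: <expression>` and evaluated against the event body to build per-request headers. A sketch of wiring it into a serving graph (the endpoint URL and header values are placeholders, and the graph wiring details are assumptions)::

    import mlrun
    from mlrun.serving.remote import RemoteStep

    fn = mlrun.new_function("enricher", kind="serving", image="mlrun/mlrun")
    graph = fn.set_topology("flow", engine="async")

    graph.to(
        RemoteStep(
            name="remote_call",
            url="https://example.com/api/score",  # placeholder endpoint
            method="POST",
            # Evaluated as `lambda event: ...` with the event body bound to `event`.
            headers_expression="{'x-request-id': str(event.get('id', '')), 'content-type': 'application/json'}",
        )
    ).respond()
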