mlrun 1.8.0rc43__py3-none-any.whl → 1.8.0rc45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic; see the registry page for details.

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta
 from io import StringIO
 from typing import Callable, Literal, Optional, Union

@@ -33,7 +33,12 @@ _TSDB_BE = "tsdb"
 _TSDB_RATE = "1/s"
 _CONTAINER = "users"

-V3IO_MEPS_LIMIT = 200
+V3IO_FRAMESD_MEPS_LIMIT = (
+    200  # Maximum number of model endpoints per single request when using V3IO Frames
+)
+V3IO_CLIENT_MEPS_LIMIT = (
+    150  # Maximum number of model endpoints per single request when using V3IO Client
+)


 def _is_no_schema_error(exc: v3io_frames.Error) -> bool:
@@ -72,6 +77,15 @@ class V3IOTSDBConnector(TSDBConnector):
         self._frames_client: Optional[v3io_frames.client.ClientBase] = None
         self._init_tables_path()
         self._create_table = create_table
+        self._v3io_client = None
+
+    @property
+    def v3io_client(self):
+        if not self._v3io_client:
+            self._v3io_client = mlrun.utils.v3io_clients.get_v3io_client(
+                endpoint=mlrun.mlconf.v3io_api, access_key=self._v3io_access_key
+            )
+        return self._v3io_client

     @property
     def frames_client(self) -> v3io_frames.client.ClientBase:
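The new `v3io_client` property lazily creates and caches the client on first access, mirroring the existing `frames_client` pattern. A minimal standalone sketch of the same memoized-property idiom (the `Client` and `Connector` classes below are illustrative stand-ins, not mlrun API):

```python
# Illustrative sketch of the lazy, cached property pattern used by
# v3io_client above; Client and Connector are hypothetical stand-ins.
from typing import Optional


class Client:
    def __init__(self, endpoint: str, access_key: str):
        # Imagine expensive connection setup happening here.
        self.endpoint = endpoint
        self.access_key = access_key


class Connector:
    def __init__(self, endpoint: str, access_key: str):
        self._endpoint = endpoint
        self._access_key = access_key
        self._client: Optional[Client] = None

    @property
    def client(self) -> Client:
        # Build the client only on first access, then reuse the cached one.
        if not self._client:
            self._client = Client(self._endpoint, self._access_key)
        return self._client
```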
@@ -147,6 +161,21 @@ class V3IOTSDBConnector(TSDBConnector):
         )
         self.tables[mm_schemas.V3IOTSDBTables.PREDICTIONS] = monitoring_predictions_path

+        # initialize kv table
+        last_request_full_table_path = (
+            mlrun.mlconf.get_model_monitoring_file_target_path(
+                project=self.project,
+                kind=mm_schemas.FileTargetKind.LAST_REQUEST,
+            )
+        )
+        (
+            _,
+            _,
+            self.last_request_table,
+        ) = mlrun.common.model_monitoring.helpers.parse_model_endpoint_store_prefix(
+            last_request_full_table_path
+        )
+
     def create_tables(self) -> None:
         """
         Create the tables using the TSDB connector. These are the tables that are stored in the V3IO TSDB:
@@ -252,6 +281,16 @@ class V3IOTSDBConnector(TSDBConnector):
             key=mm_schemas.EventFieldType.ENDPOINT_ID,
         )

+        # Write last request timestamp to KV table
+        graph.add_step(
+            "storey.NoSqlTarget",
+            name="KVLastRequest",
+            after="tsdb_predictions",
+            table=f"v3io:///users/{self.last_request_table}",
+            columns=[EventFieldType.LAST_REQUEST_TIMESTAMP],
+            index_cols=[EventFieldType.ENDPOINT_ID],
+        )
+
         # Emits the event in window size of events based on sample_window size (10 by default)
         graph.add_step(
             "storey.steps.SampleWindow",
@@ -441,8 +480,8 @@ class V3IOTSDBConnector(TSDBConnector):
         tables = mm_schemas.V3IOTSDBTables.list()

         # Split the endpoint ids into chunks to avoid exceeding the v3io-engine filter-expression limit
-        for i in range(0, len(endpoint_ids), V3IO_MEPS_LIMIT):
-            endpoint_id_chunk = endpoint_ids[i : i + V3IO_MEPS_LIMIT]
+        for i in range(0, len(endpoint_ids), V3IO_FRAMESD_MEPS_LIMIT):
+            endpoint_id_chunk = endpoint_ids[i : i + V3IO_FRAMESD_MEPS_LIMIT]
             filter_query = f"endpoint_id IN({str(endpoint_id_chunk)[1:-1]}) "
             for table in tables:
                 try:
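A standalone sketch of the chunking pattern in this hunk, stepping through the id list in fixed-size slices (the limit of 200 matches `V3IO_FRAMESD_MEPS_LIMIT` in this release; the `ep-*` ids are made up):

```python
# Sketch of the slice-stepping chunk loop above; the limit mirrors
# V3IO_FRAMESD_MEPS_LIMIT, and the endpoint ids are made up.
LIMIT = 200

endpoint_ids = [f"ep-{n}" for n in range(450)]

for i in range(0, len(endpoint_ids), LIMIT):
    chunk = endpoint_ids[i : i + LIMIT]
    # str(chunk)[1:-1] strips the surrounding brackets, leaving a
    # comma-separated, quoted id list for the filter expression, e.g.
    # endpoint_id IN('ep-0', 'ep-1', ...)
    filter_query = f"endpoint_id IN({str(chunk)[1:-1]}) "
```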
@@ -458,12 +497,31 @@ class V3IOTSDBConnector(TSDBConnector):
                         error=mlrun.errors.err_to_str(e),
                         project=self.project,
                     )
+
+        # Clean the last request records from the KV table
+        self._delete_last_request_records(endpoint_ids=endpoint_ids)
+
         logger.debug(
             "Deleted all model endpoint resources using the V3IO connector",
             project=self.project,
             number_of_endpoints_to_delete=len(endpoint_ids),
         )

+    def _delete_last_request_records(self, endpoint_ids: list[str]):
+        for endpoint_id in endpoint_ids:
+            try:
+                self.v3io_client.kv.delete(
+                    container=self.container,
+                    table=self.last_request_table,
+                    key=endpoint_id,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to delete last request record for endpoint '{endpoint_id}'",
+                    error=mlrun.errors.err_to_str(e),
+                    project=self.project,
+                )
+
     def get_model_endpoint_real_time_metrics(
         self, endpoint_id: str, metrics: list[str], start: str, end: str
     ) -> dict[str, list[tuple[str, float]]]:
@@ -631,11 +689,11 @@ class V3IOTSDBConnector(TSDBConnector):
         if isinstance(endpoint_id, str):
             return f"endpoint_id=='{endpoint_id}'"
         elif isinstance(endpoint_id, list):
-            if len(endpoint_id) > V3IO_MEPS_LIMIT:
+            if len(endpoint_id) > V3IO_FRAMESD_MEPS_LIMIT:
                 logger.info(
                     "The number of endpoint ids exceeds the v3io-engine filter-expression limit, "
                     "retrieving all the model endpoints from the db.",
-                    limit=V3IO_MEPS_LIMIT,
+                    limit=V3IO_FRAMESD_MEPS_LIMIT,
                     amount=len(endpoint_id),
                 )
                 return None
@@ -826,41 +884,51 @@ class V3IOTSDBConnector(TSDBConnector):
         endpoint_ids: Union[str, list[str]],
         start: Optional[datetime] = None,
         end: Optional[datetime] = None,
-        get_raw: bool = False,
-    ) -> Union[pd.DataFrame, list[v3io_frames.client.RawFrame]]:
-        filter_query = self._get_endpoint_filter(endpoint_id=endpoint_ids)
-        start, end = self._get_start_end(start, end)
-
-        res = self._get_records(
-            table=mm_schemas.V3IOTSDBTables.PREDICTIONS,
-            start=start,
-            end=end,
-            filter_query=filter_query,
-            agg_funcs=["last"],
-            get_raw=get_raw,
-        )
+    ) -> dict[str, float]:
+        # Get the last request timestamp for each endpoint from the KV table.
+        # The result of the query is a list of dictionaries,
+        # each dictionary contains the endpoint id and the last request timestamp.
+        last_request_timestamps = {}
+        if isinstance(endpoint_ids, str):
+            endpoint_ids = [endpoint_ids]

-        if get_raw:
-            return res
+        try:
+            if len(endpoint_ids) > V3IO_CLIENT_MEPS_LIMIT:
+                logger.warning(
+                    "The number of endpoint ids exceeds the v3io-engine filter-expression limit, "
+                    "retrieving last request for all the model endpoints from the KV table.",
+                    limit=V3IO_CLIENT_MEPS_LIMIT,
+                    amount=len(endpoint_ids),
+                )

-        df = res
-        if not df.empty:
-            df.rename(
-                columns={
-                    f"last({mm_schemas.EventFieldType.LAST_REQUEST_TIMESTAMP})": mm_schemas.EventFieldType.LAST_REQUEST,
-                    f"last({mm_schemas.EventFieldType.LATENCY})": f"last_{mm_schemas.EventFieldType.LATENCY}",
-                },
-                inplace=True,
-            )
-            df[mm_schemas.EventFieldType.LAST_REQUEST] = df[
-                mm_schemas.EventFieldType.LAST_REQUEST
-            ].map(
-                lambda last_request: datetime.fromtimestamp(
-                    last_request, tz=timezone.utc
+                res = self.v3io_client.kv.new_cursor(
+                    container=self.container,
+                    table_path=self.last_request_table,
+                ).all()
+                last_request_timestamps.update(
+                    {d["__name"]: d["last_request_timestamp"] for d in res}
                 )
+            else:
+                filter_expression = " OR ".join(
+                    [f"__name=='{endpoint_id}'" for endpoint_id in endpoint_ids]
+                )
+                res = self.v3io_client.kv.new_cursor(
+                    container=self.container,
+                    table_path=self.last_request_table,
+                    filter_expression=filter_expression,
+                ).all()
+                last_request_timestamps.update(
+                    {d["__name"]: d["last_request_timestamp"] for d in res}
+                )
+        except Exception as e:
+            logger.warning(
+                "Failed to get last request timestamp from V3IO KV table.",
+                err=mlrun.errors.err_to_str(e),
+                project=self.project,
+                table=self.last_request_table,
             )

-        return df.reset_index(drop=True)
+        return last_request_timestamps

     def get_drift_status(
         self,
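The rewritten `get_last_request` reads the KV table directly instead of aggregating over the TSDB. A sketch of its two moving parts, the `__name` filter expression and the fold of cursor records into the returned dict (the `records` list below is a stand-in for what the cursor's `all()` returns):

```python
# Sketch of the filter construction and result folding in get_last_request;
# the records list is a stand-in for the KV cursor's all() output.
endpoint_ids = ["ep-1", "ep-2"]

# Builds e.g. "__name=='ep-1' OR __name=='ep-2'"; __name is the KV key,
# i.e. the endpoint id the record was written under.
filter_expression = " OR ".join(f"__name=='{eid}'" for eid in endpoint_ids)

records = [
    {"__name": "ep-1", "last_request_timestamp": 1720000000.0},
    {"__name": "ep-2", "last_request_timestamp": 1720000123.5},
]
last_request_timestamps = {d["__name"]: d["last_request_timestamp"] for d in records}
# -> {"ep-1": 1720000000.0, "ep-2": 1720000123.5}
```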
@@ -1037,7 +1105,6 @@ class V3IOTSDBConnector(TSDBConnector):
             model_endpoint_objects_by_uid[uid] = model_endpoint_object

         error_count_res = self.get_error_count(endpoint_ids=uids, get_raw=True)
-        last_request_res = self.get_last_request(endpoint_ids=uids, get_raw=True)
         avg_latency_res = self.get_avg_latency(endpoint_ids=uids, get_raw=True)
         drift_status_res = self.get_drift_status(endpoint_ids=uids, get_raw=True)

@@ -1060,11 +1127,7 @@
             "count(error_count)",
             error_count_res,
         )
-        add_metric(
-            "last_request",
-            "last(last_request_timestamp)",
-            last_request_res,
-        )
+
         add_metric(
             "avg_latency",
             "avg(latency)",
@@ -1075,4 +1138,23 @@
             "max(result_status)",
             drift_status_res,
         )
+
+        self._enrich_mep_with_last_request(
+            model_endpoint_objects_by_uid=model_endpoint_objects_by_uid
+        )
+
         return list(model_endpoint_objects_by_uid.values())
+
+    def _enrich_mep_with_last_request(
+        self,
+        model_endpoint_objects_by_uid: dict[str, mlrun.common.schemas.ModelEndpoint],
+    ):
+        last_request_dictionary = self.get_last_request(
+            endpoint_ids=list(model_endpoint_objects_by_uid.keys())
+        )
+        for uid, mep in model_endpoint_objects_by_uid.items():
+            # Set the last request timestamp to the MEP object. If not found, keep the existing value from the
+            # DB (relevant for batch EP).
+            mep.status.last_request = last_request_dictionary.get(
+                uid, mep.status.last_request
+            )
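The enrichment helper relies on `dict.get` with a default to keep the value already loaded from the DB when the KV table has no record for an endpoint (the batch-endpoint case noted in the comment). A tiny sketch of that fallback, with a hypothetical `Status` class standing in for the model endpoint status object:

```python
# Sketch of the dict.get fallback in _enrich_mep_with_last_request;
# Status is a hypothetical stand-in for the endpoint status object.
class Status:
    def __init__(self, last_request):
        self.last_request = last_request


statuses = {"ep-1": Status(1690000000.0), "ep-2": Status(1690000000.0)}
last_request_dictionary = {"ep-1": 1720000000.0}  # ep-2 has no KV record

for uid, status in statuses.items():
    status.last_request = last_request_dictionary.get(uid, status.last_request)

# ep-1 took the fresher KV value; ep-2 kept its DB value.
assert statuses["ep-1"].last_request == 1720000000.0
assert statuses["ep-2"].last_request == 1690000000.0
```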
@@ -180,7 +180,7 @@ class ModelMonitoringWriter(StepToDict):
             data, timestamp
         )
         logger.info(
-            "Updating the model endpoint statistics",
+            "Updated the model endpoint statistics",
             endpoint_id=endpoint_id,
             stats_kind=stat_kind,
         )
mlrun/projects/project.py CHANGED
@@ -2451,7 +2451,22 @@ class MlrunProject(ModelObj):
         :param image:                  The image of the model monitoring controller, writer, monitoring
                                        stream & histogram data drift functions, which are real time nuclio
                                        functions. By default, the image is mlrun/mlrun.
-        :param deploy_histogram_data_drift_app: If true, deploy the default histogram-based data drift application.
+        :param deploy_histogram_data_drift_app: If true, deploy the default histogram-based data drift application:
+            :py:class:`~mlrun.model_monitoring.applications.histogram_data_drift.HistogramDataDriftApplication`.
+            If false, and you want to deploy the histogram data drift application
+            afterwards, you may use the
+            :py:func:`~set_model_monitoring_function` method::
+
+                import mlrun.model_monitoring.applications.histogram_data_drift as histogram_data_drift
+
+                hist_app = project.set_model_monitoring_function(
+                    name=histogram_data_drift.HistogramDataDriftApplicationConstants.NAME,  # keep the default name
+                    func=histogram_data_drift.__file__,
+                    application_class=histogram_data_drift.HistogramDataDriftApplication.__name__,
+                )
+
+                project.deploy_function(hist_app)
+
         :param wait_for_deployment:    If true, return only after the deployment is done on the backend.
                                        Otherwise, deploy the model monitoring infrastructure on the
                                        background, including the histogram data drift app if selected.
@@ -2488,30 +2503,6 @@
             )
         self._wait_for_functions_deployment(deployment_functions)

-    def deploy_histogram_data_drift_app(
-        self,
-        *,
-        image: str = "mlrun/mlrun",
-        db: Optional[mlrun.db.RunDBInterface] = None,
-        wait_for_deployment: bool = False,
-    ) -> None:
-        """
-        Deploy the histogram data drift application.
-
-        :param image:               The image on which the application will run.
-        :param db:                  An optional DB object.
-        :param wait_for_deployment: If true, return only after the deployment is done on the backend.
-                                    Otherwise, deploy the application on the background.
-        """
-        if db is None:
-            db = mlrun.db.get_run_db(secrets=self._secrets)
-        db.deploy_histogram_data_drift_app(project=self.name, image=image)
-
-        if wait_for_deployment:
-            self._wait_for_functions_deployment(
-                [mm_constants.HistogramDataDriftApplicationConstants.NAME]
-            )
-
     def update_model_monitoring_controller(
         self,
         base_period: int = 10,
@@ -5034,14 +5025,20 @@
         db = mlrun.db.get_run_db(secrets=self._secrets)
         return db.get_alert_config(alert_name, self.metadata.name)

-    def list_alerts_configs(self) -> list[AlertConfig]:
+    def list_alerts_configs(
+        self, limit: Optional[int] = None, offset: Optional[int] = None
+    ) -> list[AlertConfig]:
         """
         Retrieve list of alerts of a project.

+        :param limit:  The maximum number of alerts to return.
+                       Defaults to `mlconf.alerts.default_list_alert_configs_limit` if not provided.
+        :param offset: The number of alerts to skip before starting to collect alerts.
+
         :return: All the alerts objects of the project.
         """
         db = mlrun.db.get_run_db(secrets=self._secrets)
-        return db.list_alerts_configs(self.metadata.name)
+        return db.list_alerts_configs(self.metadata.name, limit=limit, offset=offset)

     def delete_alert_config(
         self, alert_data: AlertConfig = None, alert_name: Optional[str] = None
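A hedged usage sketch of the new pagination parameters (the project name and page size are illustrative; per the docstring, omitting `limit` falls back to `mlconf.alerts.default_list_alert_configs_limit`):

```python
# Illustrative pagination over alert configs; the project name and the
# page size of 50 are made up, the signature is as added in this release.
import mlrun

project = mlrun.get_or_create_project("my-project")

page_size = 50
offset = 0
while True:
    page = project.list_alerts_configs(limit=page_size, offset=offset)
    if not page:
        break
    for alert_config in page:
        print(alert_config.name)
    offset += page_size
```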
@@ -4,7 +4,7 @@
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
-//     http://www.apache.org/licenses/LICENSE-2.0
+//	http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
@@ -14,82 +14,84 @@
 package main

 import (
-    "bytes"
-    "fmt"
-    "net/http"
-    "net/http/httptest"
-    "net/http/httputil"
-    "net/url"
-    "os"
-    "strings"
+	"bytes"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"net/http/httputil"
+	"net/url"
+	"os"
+	"strings"

-    nuclio "github.com/nuclio/nuclio-sdk-go"
+	nuclio "github.com/nuclio/nuclio-sdk-go"
 )

 func Handler(context *nuclio.Context, event nuclio.Event) (interface{}, error) {
-    reverseProxy := context.UserData.(map[string]interface{})["reverseProxy"].(*httputil.ReverseProxy)
-    sidecarUrl := context.UserData.(map[string]interface{})["server"].(string)
+	reverseProxy := context.UserData.(map[string]interface{})["reverseProxy"].(*httputil.ReverseProxy)
+	sidecarUrl := context.UserData.(map[string]interface{})["server"].(string)

-    // populate reverse proxy http request
-    httpRequest, err := http.NewRequest(event.GetMethod(), event.GetPath(), bytes.NewReader(event.GetBody()))
-    if err != nil {
-        context.Logger.ErrorWith("Failed to create a reverse proxy request")
-        return nil, err
-    }
-    for k, v := range event.GetHeaders() {
-        httpRequest.Header[k] = []string{v.(string)}
-    }
+	// populate reverse proxy http request
+	httpRequest, err := http.NewRequest(event.GetMethod(), event.GetPath(), bytes.NewReader(event.GetBody()))
+	if err != nil {
+		context.Logger.ErrorWith("Failed to create a reverse proxy request")
+		return nil, err
+	}
+	for k, v := range event.GetHeaders() {
+		httpRequest.Header[k] = []string{v.(string)}
+	}

-    // populate query params
-    query := httpRequest.URL.Query()
-    for k, v := range event.GetFields() {
-        query.Set(k, v.(string))
-    }
-    httpRequest.URL.RawQuery = query.Encode()
+	// populate query params
+	query := httpRequest.URL.Query()
+	for k, v := range event.GetFields() {
+		query.Set(k, v.(string))
+	}
+	httpRequest.URL.RawQuery = query.Encode()

-    recorder := httptest.NewRecorder()
-    reverseProxy.ServeHTTP(recorder, httpRequest)
+	recorder := httptest.NewRecorder()
+	reverseProxy.ServeHTTP(recorder, httpRequest)

-    // send request to sidecar
-    context.Logger.DebugWith("Forwarding request to sidecar", "sidecarUrl", sidecarUrl, "query", httpRequest.URL.Query())
-    response := recorder.Result()
+	// send request to sidecar
+	context.Logger.DebugWith("Forwarding request to sidecar",
+		"sidecarUrl", sidecarUrl,
+		"method", event.GetMethod())
+	response := recorder.Result()

-    headers := make(map[string]interface{})
-    for key, value := range response.Header {
-        headers[key] = value[0]
-    }
+	headers := make(map[string]interface{})
+	for key, value := range response.Header {
+		headers[key] = value[0]
+	}

-    // let the processor calculate the content length
-    delete(headers, "Content-Length")
-    return nuclio.Response{
-        StatusCode:  response.StatusCode,
-        Body:        recorder.Body.Bytes(),
-        ContentType: response.Header.Get("Content-Type"),
-        Headers:     headers,
-    }, nil
+	// let the processor calculate the content length
+	delete(headers, "Content-Length")
+	return nuclio.Response{
+		StatusCode:  response.StatusCode,
+		Body:        recorder.Body.Bytes(),
+		ContentType: response.Header.Get("Content-Type"),
+		Headers:     headers,
+	}, nil
 }

 func InitContext(context *nuclio.Context) error {
-    sidecarHost := os.Getenv("SIDECAR_HOST")
-    sidecarPort := os.Getenv("SIDECAR_PORT")
-    if sidecarHost == "" {
-        sidecarHost = "http://localhost"
-    } else if !strings.Contains(sidecarHost, "://") {
-        sidecarHost = fmt.Sprintf("http://%s", sidecarHost)
-    }
+	sidecarHost := os.Getenv("SIDECAR_HOST")
+	sidecarPort := os.Getenv("SIDECAR_PORT")
+	if sidecarHost == "" {
+		sidecarHost = "http://localhost"
+	} else if !strings.Contains(sidecarHost, "://") {
+		sidecarHost = fmt.Sprintf("http://%s", sidecarHost)
+	}

-    // url for request forwarding
-    sidecarUrl := fmt.Sprintf("%s:%s", sidecarHost, sidecarPort)
-    parsedURL, err := url.Parse(sidecarUrl)
-    if err != nil {
-        context.Logger.ErrorWith("Failed to parse sidecar url", "sidecarUrl", sidecarUrl)
-        return err
-    }
-    reverseProxy := httputil.NewSingleHostReverseProxy(parsedURL)
+	// url for request forwarding
+	sidecarUrl := fmt.Sprintf("%s:%s", sidecarHost, sidecarPort)
+	parsedURL, err := url.Parse(sidecarUrl)
+	if err != nil {
+		context.Logger.ErrorWith("Failed to parse sidecar url", "sidecarUrl", sidecarUrl)
+		return err
+	}
+	reverseProxy := httputil.NewSingleHostReverseProxy(parsedURL)

-    context.UserData = map[string]interface{}{
-        "server":       sidecarUrl,
-        "reverseProxy": reverseProxy,
-    }
-    return nil
+	context.UserData = map[string]interface{}{
+		"server":       sidecarUrl,
+		"reverseProxy": reverseProxy,
+	}
+	return nil
 }
mlrun/serving/states.py CHANGED
@@ -959,7 +959,7 @@ class ModelRunner(storey.ParallelExecution):
         return self.model_selector.select(event, models)


-class ModelRunnerStep(TaskStep):
+class ModelRunnerStep(TaskStep, StepToDict):
     """
     Runs multiple Models on each event.

@@ -981,29 +981,41 @@ class ModelRunnerStep(TaskStep):
         model_selector: Optional[Union[str, ModelSelector]] = None,
         **kwargs,
     ):
-        self._models = []
         super().__init__(
             *args,
             class_name="mlrun.serving.ModelRunner",
-            class_args=dict(runnables=self._models, model_selector=model_selector),
+            class_args=dict(model_selector=model_selector),
             **kwargs,
         )

-    def add_model(self, model: Model) -> None:
-        """Add a Model to this ModelRunner."""
-        self._models.append(model)
+    def add_model(self, model: Union[str, Model], **model_parameters) -> None:
+        """
+        Add a Model to this ModelRunner.
+
+        :param model:            Model class name or object
+        :param model_parameters: Parameters for model instantiation
+        """
+        models = self.class_args.get("models", [])
+        models.append((model, model_parameters))
+        self.class_args["models"] = models

     def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
         model_selector = self.class_args.get("model_selector")
+        models = self.class_args.get("models")
         if isinstance(model_selector, str):
             model_selector = get_class(model_selector, namespace)()
+        model_objects = []
+        for model, model_params in models:
+            if not isinstance(model, Model):
+                model = get_class(model, namespace)(**model_params)
+            model_objects.append(model)
         self._async_object = ModelRunner(
-            self.class_args.get("runnables"),
             model_selector=model_selector,
+            runnables=model_objects,
         )


-class QueueStep(BaseStep):
+class QueueStep(BaseStep, StepToDict):
     """queue step, implement an async queue or represent a stream"""

     kind = "queue"
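With this change, models are recorded in `class_args` as `(class_name_or_object, parameters)` pairs and instantiated in `init_object`, which is what lets `ModelRunnerStep` serialize via `StepToDict`. A hedged usage sketch (constructing the step directly, plus `MyModel` and its module path, are illustrative assumptions, not a confirmed mlrun flow):

```python
# Hypothetical usage of the new add_model signature; MyModel, the module
# path, and constructing the step directly are illustrative assumptions.
from mlrun.serving.states import ModelRunnerStep

runner = ModelRunnerStep(name="runner")

# The model is passed as a class-name string plus instantiation parameters;
# the step records ("my_models.MyModel", {...}) in class_args["models"] and
# defers instantiation to init_object(), via get_class(...)(**model_parameters).
runner.add_model("my_models.MyModel", model_path="./model.pkl")
```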
@@ -1799,21 +1811,13 @@

     class_args = class_args or {}

-    if class_name and hasattr(class_name, "to_dict"):
-        struct = class_name.to_dict()
-        kind = struct.get("kind", StepKinds.task)
-        name = name or struct.get("name", struct.get("class_name"))
-        cls = classes_map.get(kind, RootFlowStep)
-        step = cls.from_dict(struct)
-        step.function = function
-        step.full_event = full_event or step.full_event
-        step.input_path = input_path or step.input_path
-        step.result_path = result_path or step.result_path
-        if kind == StepKinds.task:
-            step.model_endpoint_creation_strategy = model_endpoint_creation_strategy
-            step.endpoint_type = endpoint_type
+    if isinstance(class_name, QueueStep):
+        if not name or class_name.name:
+            raise MLRunInvalidArgumentError("queue name must be specified")
+
+        step = class_name

-    elif class_name and class_name in queue_class_names:
+    elif class_name in queue_class_names:
         if "path" not in class_args:
             raise MLRunInvalidArgumentError(
                 "path=<stream path or None> must be specified for queues"
@@ -1826,6 +1830,20 @@
             class_args["full_event"] = full_event
         step = QueueStep(name, **class_args)

+    elif class_name and hasattr(class_name, "to_dict"):
+        struct = class_name.to_dict()
+        kind = struct.get("kind", StepKinds.task)
+        name = name or struct.get("name", struct.get("class_name"))
+        cls = classes_map.get(kind, RootFlowStep)
+        step = cls.from_dict(struct)
+        step.function = function
+        step.full_event = full_event or step.full_event
+        step.input_path = input_path or step.input_path
+        step.result_path = result_path or step.result_path
+        if kind == StepKinds.task:
+            step.model_endpoint_creation_strategy = model_endpoint_creation_strategy
+            step.endpoint_type = endpoint_type
+
     elif class_name and class_name.startswith("*"):
         routes = class_args.get("routes", None)
         class_name = class_name[1:]
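The branch reordering matters because `QueueStep` now mixes in `StepToDict`, so a queue instance also satisfies the duck-typed `hasattr(class_name, "to_dict")` test; checking `isinstance(class_name, QueueStep)` first keeps queues out of the generic branch. A minimal sketch of the dispatch-order pitfall:

```python
# Minimal sketch of why the specific isinstance branch must precede the
# duck-typed hasattr branch once QueueStep gains a to_dict method.
class StepToDict:
    def to_dict(self):
        return {"kind": "queue"}


class QueueStep(StepToDict):
    pass


def params_to_step_dispatch(obj):
    if isinstance(obj, QueueStep):  # specific check first
        return "queue"
    elif hasattr(obj, "to_dict"):  # generic duck-typed fallback
        return "generic"


# With the checks in this order a queue is handled as a queue; swapping
# the branches would route it through the generic to_dict path instead.
assert params_to_step_dispatch(QueueStep()) == "queue"
```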
@@ -1,4 +1,4 @@
 {
-  "git_commit": "4ffe55c0a4818e4a58b7a80b4a05a1932e4ff99b",
-  "version": "1.8.0-rc43"
+  "git_commit": "b434e35c26bb66407a60bedddb7a9af71141902b",
+  "version": "1.8.0-rc45"
 }
@@ -1,6 +1,6 @@
-Metadata-Version: 2.2
+Metadata-Version: 2.4
 Name: mlrun
-Version: 1.8.0rc43
+Version: 1.8.0rc45
 Summary: Tracking and config of machine learning runs
 Home-page: https://github.com/mlrun/mlrun
 Author: Yaron Haviv
@@ -240,6 +240,7 @@ Dynamic: description-content-type
 Dynamic: home-page
 Dynamic: keywords
 Dynamic: license
+Dynamic: license-file
 Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python