mlrun 1.8.0rc4__py3-none-any.whl → 1.8.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; further details were linked from the original page.

Files changed (75)
  1. mlrun/__init__.py +5 -3
  2. mlrun/alerts/alert.py +129 -2
  3. mlrun/artifacts/__init__.py +1 -1
  4. mlrun/artifacts/base.py +12 -1
  5. mlrun/artifacts/document.py +59 -38
  6. mlrun/common/constants.py +1 -0
  7. mlrun/common/model_monitoring/__init__.py +0 -2
  8. mlrun/common/model_monitoring/helpers.py +0 -28
  9. mlrun/common/schemas/__init__.py +2 -4
  10. mlrun/common/schemas/alert.py +80 -1
  11. mlrun/common/schemas/artifact.py +4 -0
  12. mlrun/common/schemas/client_spec.py +0 -1
  13. mlrun/common/schemas/model_monitoring/__init__.py +0 -6
  14. mlrun/common/schemas/model_monitoring/constants.py +11 -9
  15. mlrun/common/schemas/model_monitoring/model_endpoints.py +77 -149
  16. mlrun/common/schemas/notification.py +6 -0
  17. mlrun/common/schemas/project.py +3 -0
  18. mlrun/config.py +2 -3
  19. mlrun/datastore/datastore_profile.py +57 -17
  20. mlrun/datastore/sources.py +1 -2
  21. mlrun/datastore/vectorstore.py +67 -59
  22. mlrun/db/base.py +29 -19
  23. mlrun/db/factory.py +0 -3
  24. mlrun/db/httpdb.py +224 -161
  25. mlrun/db/nopdb.py +36 -17
  26. mlrun/execution.py +46 -32
  27. mlrun/feature_store/api.py +1 -0
  28. mlrun/model.py +7 -0
  29. mlrun/model_monitoring/__init__.py +3 -2
  30. mlrun/model_monitoring/api.py +55 -53
  31. mlrun/model_monitoring/applications/_application_steps.py +4 -2
  32. mlrun/model_monitoring/applications/base.py +165 -6
  33. mlrun/model_monitoring/applications/context.py +88 -37
  34. mlrun/model_monitoring/applications/evidently_base.py +0 -1
  35. mlrun/model_monitoring/applications/histogram_data_drift.py +3 -7
  36. mlrun/model_monitoring/controller.py +43 -37
  37. mlrun/model_monitoring/db/__init__.py +0 -2
  38. mlrun/model_monitoring/db/tsdb/base.py +2 -1
  39. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +2 -1
  40. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +43 -0
  41. mlrun/model_monitoring/helpers.py +79 -66
  42. mlrun/model_monitoring/stream_processing.py +83 -270
  43. mlrun/model_monitoring/writer.py +1 -10
  44. mlrun/projects/pipelines.py +37 -1
  45. mlrun/projects/project.py +171 -74
  46. mlrun/run.py +40 -0
  47. mlrun/runtimes/nuclio/function.py +7 -6
  48. mlrun/runtimes/nuclio/serving.py +9 -2
  49. mlrun/serving/routers.py +158 -145
  50. mlrun/serving/server.py +6 -0
  51. mlrun/serving/states.py +21 -7
  52. mlrun/serving/v2_serving.py +70 -61
  53. mlrun/utils/helpers.py +14 -30
  54. mlrun/utils/notifications/notification/mail.py +36 -9
  55. mlrun/utils/notifications/notification_pusher.py +43 -18
  56. mlrun/utils/version/version.json +2 -2
  57. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc7.dist-info}/METADATA +5 -4
  58. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc7.dist-info}/RECORD +62 -75
  59. mlrun/common/schemas/model_monitoring/model_endpoint_v2.py +0 -149
  60. mlrun/model_monitoring/db/stores/__init__.py +0 -136
  61. mlrun/model_monitoring/db/stores/base/__init__.py +0 -15
  62. mlrun/model_monitoring/db/stores/base/store.py +0 -154
  63. mlrun/model_monitoring/db/stores/sqldb/__init__.py +0 -13
  64. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +0 -46
  65. mlrun/model_monitoring/db/stores/sqldb/models/base.py +0 -93
  66. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +0 -47
  67. mlrun/model_monitoring/db/stores/sqldb/models/sqlite.py +0 -25
  68. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +0 -408
  69. mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +0 -13
  70. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +0 -464
  71. mlrun/model_monitoring/model_endpoint.py +0 -120
  72. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc7.dist-info}/LICENSE +0 -0
  73. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc7.dist-info}/WHEEL +0 -0
  74. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc7.dist-info}/entry_points.txt +0 -0
  75. {mlrun-1.8.0rc4.dist-info → mlrun-1.8.0rc7.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,6 @@
14
14
 
15
15
  import collections
16
16
  import datetime
17
- import json
18
17
  import os
19
18
  import typing
20
19
 
@@ -32,13 +31,10 @@ import mlrun.utils
32
31
  from mlrun.common.schemas.model_monitoring.constants import (
33
32
  EndpointType,
34
33
  EventFieldType,
35
- EventKeyMetrics,
36
- EventLiveStats,
37
34
  FileTargetKind,
38
- ModelEndpointTarget,
39
35
  ProjectSecretKeys,
40
36
  )
41
- from mlrun.model_monitoring.db import StoreBase, TSDBConnector
37
+ from mlrun.model_monitoring.db import TSDBConnector
42
38
  from mlrun.utils import logger
43
39
 
44
40
 
@@ -102,18 +98,6 @@ class EventStreamProcessor:
102
98
  v3io_access_key=self.model_monitoring_access_key, v3io_api=self.v3io_api
103
99
  )
104
100
 
105
- # KV path
106
- kv_path = mlrun.mlconf.get_model_monitoring_file_target_path(
107
- project=self.project, kind=FileTargetKind.ENDPOINTS
108
- )
109
- (
110
- _,
111
- self.kv_container,
112
- self.kv_path,
113
- ) = mlrun.common.model_monitoring.helpers.parse_model_endpoint_store_prefix(
114
- kv_path
115
- )
116
-
117
101
  # TSDB path and configurations
118
102
  tsdb_path = mlrun.mlconf.get_model_monitoring_file_target_path(
119
103
  project=self.project, kind=FileTargetKind.EVENTS
@@ -134,7 +118,6 @@ class EventStreamProcessor:
134
118
  self,
135
119
  fn: mlrun.runtimes.ServingRuntime,
136
120
  tsdb_connector: TSDBConnector,
137
- endpoint_store: StoreBase,
138
121
  ) -> None:
139
122
  """
140
123
  Apply monitoring serving graph to a given serving function. The following serving graph includes about 4 main
@@ -163,31 +146,23 @@ class EventStreamProcessor:
163
146
 
164
147
  :param fn: A serving function.
165
148
  :param tsdb_connector: Time series database connector.
166
- :param endpoint_store: KV/SQL store used for endpoint data.
167
149
  """
168
150
 
169
151
  graph = typing.cast(
170
152
  mlrun.serving.states.RootFlowStep,
171
153
  fn.set_topology(mlrun.serving.states.StepKinds.flow),
172
154
  )
173
- graph.add_step(
174
- "ExtractEndpointID",
175
- "extract_endpoint",
176
- full_event=True,
177
- )
178
155
 
179
156
  # split the graph between event with error vs valid event
180
157
  graph.add_step(
181
158
  "storey.Filter",
182
159
  "FilterError",
183
- after="extract_endpoint",
184
160
  _fn="(event.get('error') is None)",
185
161
  )
186
162
 
187
163
  graph.add_step(
188
164
  "storey.Filter",
189
165
  "ForwardError",
190
- after="extract_endpoint",
191
166
  _fn="(event.get('error') is not None)",
192
167
  )
193
168
 
@@ -199,7 +174,7 @@ class EventStreamProcessor:
199
174
  def apply_process_endpoint_event():
200
175
  graph.add_step(
201
176
  "ProcessEndpointEvent",
202
- after="extract_endpoint", # TODO: change this to FilterError in ML-7456
177
+ after="FilterError",
203
178
  full_event=True,
204
179
  project=self.project,
205
180
  )
@@ -234,79 +209,11 @@ class EventStreamProcessor:
234
209
  )
235
210
 
236
211
  apply_map_feature_names()
237
-
238
- # Calculate number of predictions and average latency
239
- def apply_storey_aggregations():
240
- # Calculate number of predictions for each window (5 min and 1 hour by default)
241
- graph.add_step(
242
- class_name="storey.AggregateByKey",
243
- aggregates=[
244
- {
245
- "name": EventFieldType.LATENCY,
246
- "column": EventFieldType.LATENCY,
247
- "operations": ["count", "avg"],
248
- "windows": self.aggregate_windows,
249
- "period": self.aggregate_period,
250
- }
251
- ],
252
- name=EventFieldType.LATENCY,
253
- after="MapFeatureNames",
254
- step_name="Aggregates",
255
- table=".",
256
- key_field=EventFieldType.ENDPOINT_ID,
257
- )
258
- # Calculate average latency time for each window (5 min and 1 hour by default)
259
- graph.add_step(
260
- class_name="storey.Rename",
261
- mapping={
262
- "latency_count_5m": EventLiveStats.PREDICTIONS_COUNT_5M,
263
- "latency_count_1h": EventLiveStats.PREDICTIONS_COUNT_1H,
264
- },
265
- name="Rename",
266
- after=EventFieldType.LATENCY,
267
- )
268
-
269
- apply_storey_aggregations()
270
-
271
- # KV/SQL branch
272
- # Filter relevant keys from the event before writing the data into the database table
273
- def apply_process_before_endpoint_update():
274
- graph.add_step(
275
- "ProcessBeforeEndpointUpdate",
276
- name="ProcessBeforeEndpointUpdate",
277
- after="Rename",
278
- )
279
-
280
- apply_process_before_endpoint_update()
281
-
282
- # Write the filtered event to KV/SQL table. At this point, the serving graph updates the stats
283
- # about average latency and the amount of predictions over time
284
- def apply_update_endpoint():
285
- graph.add_step(
286
- "UpdateEndpoint",
287
- name="UpdateEndpoint",
288
- after="ProcessBeforeEndpointUpdate",
289
- project=self.project,
290
- )
291
-
292
- apply_update_endpoint()
293
-
294
- # (only for V3IO KV target) - Apply infer_schema on the model endpoints table for generating schema file
295
- # which will be used by Grafana monitoring dashboards
296
- def apply_infer_schema():
297
- graph.add_step(
298
- "InferSchema",
299
- name="InferSchema",
300
- after="UpdateEndpoint",
301
- v3io_framesd=self.v3io_framesd,
302
- container=self.kv_container,
303
- table=self.kv_path,
304
- )
305
-
306
- if endpoint_store.type == ModelEndpointTarget.V3IO_NOSQL:
307
- apply_infer_schema()
308
-
309
- tsdb_connector.apply_monitoring_stream_steps(graph=graph)
212
+ tsdb_connector.apply_monitoring_stream_steps(
213
+ graph=graph,
214
+ aggregate_windows=self.aggregate_windows,
215
+ aggregate_period=self.aggregate_period,
216
+ )
310
217
 
311
218
  # Parquet branch
312
219
  # Filter and validate different keys before writing the data to Parquet target
@@ -342,91 +249,6 @@ class EventStreamProcessor:
342
249
  apply_parquet_target()
343
250
 
344
251
 
345
- class ProcessBeforeEndpointUpdate(mlrun.feature_store.steps.MapClass):
346
- def __init__(self, **kwargs):
347
- """
348
- Filter relevant keys from the event before writing the data to database table (in EndpointUpdate step).
349
- Note that in the endpoint table we only keep metadata (function_uri, model_class, etc.) and stats about the
350
- average latency and the number of predictions (per 5min and 1hour).
351
-
352
- :returns: A filtered event as a dictionary which will be written to the endpoint table in the next step.
353
- """
354
- super().__init__(**kwargs)
355
-
356
- def do(self, event):
357
- # Compute prediction per second
358
- event[EventLiveStats.PREDICTIONS_PER_SECOND] = (
359
- float(event[EventLiveStats.PREDICTIONS_COUNT_5M]) / 300
360
- )
361
- # Filter relevant keys
362
- e = {
363
- k: event[k]
364
- for k in [
365
- EventFieldType.FUNCTION_URI,
366
- EventFieldType.MODEL,
367
- EventFieldType.MODEL_CLASS,
368
- EventFieldType.ENDPOINT_ID,
369
- EventFieldType.LABELS,
370
- EventFieldType.FIRST_REQUEST,
371
- EventFieldType.LAST_REQUEST,
372
- EventFieldType.ERROR_COUNT,
373
- ]
374
- }
375
-
376
- # Add generic metrics statistics
377
- generic_metrics = {
378
- k: event[k]
379
- for k in [
380
- EventLiveStats.LATENCY_AVG_5M,
381
- EventLiveStats.LATENCY_AVG_1H,
382
- EventLiveStats.PREDICTIONS_PER_SECOND,
383
- EventLiveStats.PREDICTIONS_COUNT_5M,
384
- EventLiveStats.PREDICTIONS_COUNT_1H,
385
- ]
386
- }
387
-
388
- e[EventFieldType.METRICS] = json.dumps(
389
- {EventKeyMetrics.GENERIC: generic_metrics}
390
- )
391
-
392
- # Write labels as json string as required by the DB format
393
- e[EventFieldType.LABELS] = json.dumps(e[EventFieldType.LABELS])
394
-
395
- return e
396
-
397
-
398
- class ExtractEndpointID(mlrun.feature_store.steps.MapClass):
399
- def __init__(self, **kwargs) -> None:
400
- """
401
- Generate the model endpoint ID based on the event parameters and attach it to the event.
402
- """
403
- super().__init__(**kwargs)
404
-
405
- def do(self, full_event) -> typing.Union[storey.Event, None]:
406
- # Getting model version and function uri from event
407
- # and use them for retrieving the endpoint_id
408
- function_uri = full_event.body.get(EventFieldType.FUNCTION_URI)
409
- if not is_not_none(function_uri, [EventFieldType.FUNCTION_URI]):
410
- return None
411
-
412
- model = full_event.body.get(EventFieldType.MODEL)
413
- if not is_not_none(model, [EventFieldType.MODEL]):
414
- return None
415
-
416
- version = full_event.body.get(EventFieldType.VERSION)
417
- versioned_model = f"{model}:{version}" if version else f"{model}:latest"
418
-
419
- endpoint_id = mlrun.common.model_monitoring.create_model_endpoint_uid(
420
- function_uri=function_uri,
421
- versioned_model=versioned_model,
422
- )
423
-
424
- endpoint_id = str(endpoint_id)
425
- full_event.body[EventFieldType.ENDPOINT_ID] = endpoint_id
426
- full_event.body[EventFieldType.VERSIONED_MODEL] = versioned_model
427
- return full_event
428
-
429
-
430
252
  class ProcessBeforeParquet(mlrun.feature_store.steps.MapClass):
431
253
  def __init__(self, **kwargs):
432
254
  """
@@ -499,20 +321,27 @@ class ProcessEndpointEvent(mlrun.feature_store.steps.MapClass):
499
321
 
500
322
  def do(self, full_event):
501
323
  event = full_event.body
324
+ # Getting model version and function uri from event
325
+ # and use them for retrieving the endpoint_id
326
+ function_uri = full_event.body.get(EventFieldType.FUNCTION_URI)
327
+ if not is_not_none(function_uri, [EventFieldType.FUNCTION_URI]):
328
+ return None
329
+
330
+ model = full_event.body.get(EventFieldType.MODEL)
331
+ if not is_not_none(model, [EventFieldType.MODEL]):
332
+ return None
333
+
334
+ version = full_event.body.get(EventFieldType.VERSION)
335
+ versioned_model = f"{model}:{version}" if version else f"{model}:latest"
502
336
 
503
- versioned_model = event[EventFieldType.VERSIONED_MODEL]
337
+ full_event.body[EventFieldType.VERSIONED_MODEL] = versioned_model
504
338
  endpoint_id = event[EventFieldType.ENDPOINT_ID]
505
- function_uri = event[EventFieldType.FUNCTION_URI]
506
339
 
507
340
  # In case this process fails, resume state from existing record
508
- self.resume_state(endpoint_id)
509
-
510
- # If error key has been found in the current event,
511
- # increase the error counter by 1 and raise the error description
512
- error = event.get("error")
513
- if error: # TODO: delete this in ML-7456
514
- self.error_count[endpoint_id] += 1
515
- raise mlrun.errors.MLRunInvalidArgumentError(str(error))
341
+ self.resume_state(
342
+ endpoint_id,
343
+ full_event.body.get(EventFieldType.MODEL),
344
+ )
516
345
 
517
346
  # Validate event fields
518
347
  model_class = event.get("model_class") or event.get("class")
@@ -536,11 +365,6 @@ class ProcessEndpointEvent(mlrun.feature_store.steps.MapClass):
536
365
  # Set time for the first request of the current endpoint
537
366
  self.first_request[endpoint_id] = timestamp
538
367
 
539
- # Validate that the request time of the current event is later than the previous request time
540
- self._validate_last_request_timestamp(
541
- endpoint_id=endpoint_id, timestamp=timestamp
542
- )
543
-
544
368
  # Set time for the last reqeust of the current endpoint
545
369
  self.last_request[endpoint_id] = timestamp
546
370
 
@@ -610,6 +434,7 @@ class ProcessEndpointEvent(mlrun.feature_store.steps.MapClass):
610
434
  {
611
435
  EventFieldType.FUNCTION_URI: function_uri,
612
436
  EventFieldType.MODEL: versioned_model,
437
+ EventFieldType.ENDPOINT_NAME: event.get(EventFieldType.MODEL),
613
438
  EventFieldType.MODEL_CLASS: model_class,
614
439
  EventFieldType.TIMESTAMP: timestamp,
615
440
  EventFieldType.ENDPOINT_ID: endpoint_id,
@@ -636,33 +461,19 @@ class ProcessEndpointEvent(mlrun.feature_store.steps.MapClass):
636
461
  storey_event = storey.Event(body=events, key=endpoint_id)
637
462
  return storey_event
638
463
 
639
- def _validate_last_request_timestamp(self, endpoint_id: str, timestamp: str):
640
- """Validate that the request time of the current event is later than the previous request time that has
641
- already been processed.
642
-
643
- :param endpoint_id: The unique id of the model endpoint.
644
- :param timestamp: Event request time as a string.
645
-
646
- :raise MLRunPreconditionFailedError: If the request time of the current is later than the previous request time.
647
- """
648
-
649
- if (
650
- endpoint_id in self.last_request
651
- and self.last_request[endpoint_id] > timestamp
652
- ):
653
- logger.error(
654
- f"current event request time {timestamp} is earlier than the last request time "
655
- f"{self.last_request[endpoint_id]} - write to TSDB will be rejected"
656
- )
657
-
658
- def resume_state(self, endpoint_id):
464
+ def resume_state(self, endpoint_id, endpoint_name):
659
465
  # Make sure process is resumable, if process fails for any reason, be able to pick things up close to where we
660
466
  # left them
661
467
  if endpoint_id not in self.endpoints:
662
468
  logger.info("Trying to resume state", endpoint_id=endpoint_id)
663
- endpoint_record = mlrun.model_monitoring.helpers.get_endpoint_record(
664
- project=self.project,
665
- endpoint_id=endpoint_id,
469
+ endpoint_record = (
470
+ mlrun.db.get_run_db()
471
+ .get_model_endpoint(
472
+ project=self.project,
473
+ endpoint_id=endpoint_id,
474
+ name=endpoint_name,
475
+ )
476
+ .flat_dict()
666
477
  )
667
478
 
668
479
  # If model endpoint found, get first_request, last_request and error_count values
@@ -736,6 +547,7 @@ class MapFeatureNames(mlrun.feature_store.steps.MapClass):
736
547
  # and labels columns were not found in the current event
737
548
  self.feature_names = {}
738
549
  self.label_columns = {}
550
+ self.first_request = {}
739
551
 
740
552
  # Dictionary to manage the model endpoint types - important for the V3IO TSDB
741
553
  self.endpoint_type = {}
@@ -767,17 +579,22 @@ class MapFeatureNames(mlrun.feature_store.steps.MapClass):
767
579
  if isinstance(feature_value, int):
768
580
  feature_values[index] = float(feature_value)
769
581
 
582
+ attributes_to_update = {}
583
+ endpoint_record = None
770
584
  # Get feature names and label columns
771
585
  if endpoint_id not in self.feature_names:
772
- endpoint_record = mlrun.model_monitoring.helpers.get_endpoint_record(
773
- project=self.project,
774
- endpoint_id=endpoint_id,
586
+ endpoint_record = (
587
+ mlrun.db.get_run_db()
588
+ .get_model_endpoint(
589
+ project=self.project,
590
+ endpoint_id=endpoint_id,
591
+ name=event[EventFieldType.ENDPOINT_NAME],
592
+ )
593
+ .flat_dict()
775
594
  )
776
595
  feature_names = endpoint_record.get(EventFieldType.FEATURE_NAMES)
777
- feature_names = json.loads(feature_names) if feature_names else None
778
596
 
779
597
  label_columns = endpoint_record.get(EventFieldType.LABEL_NAMES)
780
- label_columns = json.loads(label_columns) if label_columns else None
781
598
 
782
599
  # If feature names were not found,
783
600
  # try to retrieve them from the previous events of the current process
@@ -795,13 +612,7 @@ class MapFeatureNames(mlrun.feature_store.steps.MapClass):
795
612
  ]
796
613
 
797
614
  # Update the endpoint record with the generated features
798
- update_endpoint_record(
799
- project=self.project,
800
- endpoint_id=endpoint_id,
801
- attributes={
802
- EventFieldType.FEATURE_NAMES: json.dumps(feature_names)
803
- },
804
- )
615
+ attributes_to_update[EventFieldType.FEATURE_NAMES] = feature_names
805
616
 
806
617
  if endpoint_type != EndpointType.ROUTER.value:
807
618
  update_monitoring_feature_set(
@@ -822,12 +633,7 @@ class MapFeatureNames(mlrun.feature_store.steps.MapClass):
822
633
  label_columns = [
823
634
  f"p{i}" for i, _ in enumerate(event[EventFieldType.PREDICTION])
824
635
  ]
825
-
826
- update_endpoint_record(
827
- project=self.project,
828
- endpoint_id=endpoint_id,
829
- attributes={EventFieldType.LABEL_NAMES: json.dumps(label_columns)},
830
- )
636
+ attributes_to_update[EventFieldType.LABEL_NAMES] = label_columns
831
637
  if endpoint_type != EndpointType.ROUTER.value:
832
638
  update_monitoring_feature_set(
833
639
  endpoint_record=endpoint_record,
@@ -848,6 +654,37 @@ class MapFeatureNames(mlrun.feature_store.steps.MapClass):
848
654
  # Update the endpoint type within the endpoint types dictionary
849
655
  self.endpoint_type[endpoint_id] = endpoint_type
850
656
 
657
+ # Update the first request time in the endpoint record
658
+ if endpoint_id not in self.first_request:
659
+ endpoint_record = endpoint_record or (
660
+ mlrun.db.get_run_db()
661
+ .get_model_endpoint(
662
+ project=self.project,
663
+ endpoint_id=endpoint_id,
664
+ name=event[EventFieldType.ENDPOINT_NAME],
665
+ )
666
+ .flat_dict()
667
+ )
668
+ if not endpoint_record.get(EventFieldType.FIRST_REQUEST):
669
+ attributes_to_update[EventFieldType.FIRST_REQUEST] = (
670
+ mlrun.utils.enrich_datetime_with_tz_info(
671
+ event[EventFieldType.FIRST_REQUEST]
672
+ )
673
+ )
674
+ self.first_request[endpoint_id] = True
675
+ if attributes_to_update:
676
+ logger.info(
677
+ "Updating endpoint record",
678
+ endpoint_id=endpoint_id,
679
+ attributes=attributes_to_update,
680
+ )
681
+ update_endpoint_record(
682
+ project=self.project,
683
+ endpoint_id=endpoint_id,
684
+ attributes=attributes_to_update,
685
+ endpoint_name=event[EventFieldType.ENDPOINT_NAME],
686
+ )
687
+
851
688
  # Add feature_name:value pairs along with a mapping dictionary of all of these pairs
852
689
  feature_names = self.feature_names[endpoint_id]
853
690
  self._map_dictionary_values(
@@ -898,30 +735,6 @@ class MapFeatureNames(mlrun.feature_store.steps.MapClass):
898
735
  event[mapping_dictionary][name] = value
899
736
 
900
737
 
901
- class UpdateEndpoint(mlrun.feature_store.steps.MapClass):
902
- def __init__(self, project: str, **kwargs):
903
- """
904
- Update the model endpoint record in the DB. Note that the event at this point includes metadata and stats about
905
- the average latency and the amount of predictions over time. This data will be used in the monitoring dashboards
906
- such as "Model Monitoring - Performance" which can be found in Grafana.
907
-
908
- :returns: Event as a dictionary (without any changes) for the next step (InferSchema).
909
- """
910
- super().__init__(**kwargs)
911
- self.project = project
912
-
913
- def do(self, event: dict):
914
- # Remove labels from the event
915
- event.pop(EventFieldType.LABELS)
916
-
917
- update_endpoint_record(
918
- project=self.project,
919
- endpoint_id=event.pop(EventFieldType.ENDPOINT_ID),
920
- attributes=event,
921
- )
922
- return event
923
-
924
-
925
738
  class InferSchema(mlrun.feature_store.steps.MapClass):
926
739
  def __init__(
927
740
  self,
@@ -966,14 +779,14 @@ class InferSchema(mlrun.feature_store.steps.MapClass):
966
779
  def update_endpoint_record(
967
780
  project: str,
968
781
  endpoint_id: str,
782
+ endpoint_name: str,
969
783
  attributes: dict,
970
784
  ):
971
- model_endpoint_store = mlrun.model_monitoring.get_store_object(
785
+ mlrun.db.get_run_db().patch_model_endpoint(
972
786
  project=project,
973
- )
974
-
975
- model_endpoint_store.update_model_endpoint(
976
- endpoint_id=endpoint_id, attributes=attributes
787
+ endpoint_id=endpoint_id,
788
+ attributes=attributes,
789
+ name=endpoint_name,
977
790
  )
978
791
 
979
792
 
@@ -21,7 +21,6 @@ import mlrun.common.schemas
21
21
  import mlrun.common.schemas.alert as alert_objects
22
22
  import mlrun.model_monitoring
23
23
  from mlrun.common.schemas.model_monitoring.constants import (
24
- EventFieldType,
25
24
  HistogramDataDriftApplicationConstants,
26
25
  MetricData,
27
26
  ResultData,
@@ -121,9 +120,6 @@ class ModelMonitoringWriter(StepToDict):
121
120
  notification_types=[NotificationKind.slack]
122
121
  )
123
122
 
124
- self._app_result_store = mlrun.model_monitoring.get_store_object(
125
- project=self.project, secret_provider=secret_provider
126
- )
127
123
  self._tsdb_connector = mlrun.model_monitoring.get_tsdb_connector(
128
124
  project=self.project, secret_provider=secret_provider
129
125
  )
@@ -266,14 +262,9 @@ class ModelMonitoringWriter(StepToDict):
266
262
  == ResultStatusApp.potential_detection.value
267
263
  )
268
264
  ):
269
- endpoint_id = event[WriterEvent.ENDPOINT_ID]
270
- endpoint_record = self._endpoints_records.setdefault(
271
- endpoint_id,
272
- self._app_result_store.get_model_endpoint(endpoint_id=endpoint_id),
273
- )
274
265
  event_value = {
275
266
  "app_name": event[WriterEvent.APPLICATION_NAME],
276
- "model": endpoint_record.get(EventFieldType.MODEL),
267
+ "model": event[WriterEvent.ENDPOINT_NAME],
277
268
  "model_endpoint_id": event[WriterEvent.ENDPOINT_ID],
278
269
  "result_name": event[ResultData.RESULT_NAME],
279
270
  "result_value": event[ResultData.RESULT_VALUE],
@@ -39,7 +39,7 @@ from mlrun.utils import (
39
39
 
40
40
  from ..common.helpers import parse_versioned_object_uri
41
41
  from ..config import config
42
- from ..run import _run_pipeline, wait_for_pipeline_completion
42
+ from ..run import _run_pipeline, retry_pipeline, wait_for_pipeline_completion
43
43
  from ..runtimes.pod import AutoMountType
44
44
 
45
45
 
@@ -421,6 +421,13 @@ class _PipelineRunStatus:
421
421
  self._state = returned_state
422
422
  return self._state
423
423
 
424
+ def retry(self) -> str:
425
+ run_id = self._engine.retry(
426
+ self,
427
+ project=self.project,
428
+ )
429
+ return run_id
430
+
424
431
  def __str__(self):
425
432
  return str(self.run_id)
426
433
 
@@ -440,6 +447,17 @@ class _PipelineRunner(abc.ABC):
440
447
  f"Save operation not supported in {cls.engine} pipeline engine"
441
448
  )
442
449
 
450
+ @classmethod
451
+ @abc.abstractmethod
452
+ def retry(
453
+ cls,
454
+ run: "_PipelineRunStatus",
455
+ project: typing.Optional["mlrun.projects.MlrunProject"] = None,
456
+ ) -> str:
457
+ raise NotImplementedError(
458
+ f"Retry operation not supported in {cls.engine} pipeline engine"
459
+ )
460
+
443
461
  @classmethod
444
462
  @abc.abstractmethod
445
463
  def run(
@@ -635,6 +653,24 @@ class _KFPRunner(_PipelineRunner):
635
653
  pipeline_context.clear()
636
654
  return _PipelineRunStatus(run_id, cls, project=project, workflow=workflow_spec)
637
655
 
656
+ @classmethod
657
+ def retry(
658
+ cls,
659
+ run: "_PipelineRunStatus",
660
+ project: typing.Optional["mlrun.projects.MlrunProject"] = None,
661
+ ) -> str:
662
+ project_name = project.metadata.name if project else ""
663
+ logger.info(
664
+ "Retrying pipeline",
665
+ run_id=run.run_id,
666
+ project=project_name,
667
+ )
668
+ run_id = retry_pipeline(
669
+ run.run_id,
670
+ project=project_name,
671
+ )
672
+ return run_id
673
+
638
674
  @staticmethod
639
675
  def wait_for_completion(
640
676
  run: "_PipelineRunStatus",