mlrun 1.8.0rc10__py3-none-any.whl → 1.8.0rc13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; see the registry's advisory page for more details.

Files changed (40)
  1. mlrun/artifacts/document.py +32 -6
  2. mlrun/common/constants.py +1 -0
  3. mlrun/common/formatters/artifact.py +1 -1
  4. mlrun/common/schemas/__init__.py +2 -0
  5. mlrun/common/schemas/model_monitoring/__init__.py +1 -0
  6. mlrun/common/schemas/model_monitoring/constants.py +6 -0
  7. mlrun/common/schemas/model_monitoring/model_endpoints.py +35 -0
  8. mlrun/common/schemas/partition.py +23 -18
  9. mlrun/datastore/vectorstore.py +69 -26
  10. mlrun/db/base.py +14 -0
  11. mlrun/db/httpdb.py +48 -1
  12. mlrun/db/nopdb.py +13 -0
  13. mlrun/execution.py +43 -11
  14. mlrun/feature_store/steps.py +1 -1
  15. mlrun/model_monitoring/api.py +26 -19
  16. mlrun/model_monitoring/applications/_application_steps.py +1 -1
  17. mlrun/model_monitoring/applications/base.py +44 -7
  18. mlrun/model_monitoring/applications/context.py +94 -71
  19. mlrun/projects/pipelines.py +6 -3
  20. mlrun/projects/project.py +95 -17
  21. mlrun/runtimes/nuclio/function.py +2 -1
  22. mlrun/runtimes/nuclio/serving.py +33 -5
  23. mlrun/serving/__init__.py +8 -0
  24. mlrun/serving/merger.py +1 -1
  25. mlrun/serving/remote.py +17 -5
  26. mlrun/serving/routers.py +36 -87
  27. mlrun/serving/server.py +6 -2
  28. mlrun/serving/states.py +162 -13
  29. mlrun/serving/v2_serving.py +39 -82
  30. mlrun/utils/helpers.py +6 -0
  31. mlrun/utils/notifications/notification/base.py +1 -1
  32. mlrun/utils/notifications/notification/webhook.py +13 -12
  33. mlrun/utils/notifications/notification_pusher.py +18 -23
  34. mlrun/utils/version/version.json +2 -2
  35. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/METADATA +10 -10
  36. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/RECORD +40 -40
  37. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/LICENSE +0 -0
  38. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/WHEEL +0 -0
  39. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/entry_points.txt +0 -0
  40. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py CHANGED
@@ -25,11 +25,12 @@ import pathlib
25
25
  import traceback
26
26
  from copy import copy, deepcopy
27
27
  from inspect import getfullargspec, signature
28
- from typing import Any, Optional, Union
28
+ from typing import Any, Optional, Union, cast
29
29
 
30
30
  import storey.utils
31
31
 
32
32
  import mlrun
33
+ import mlrun.common.schemas as schemas
33
34
 
34
35
  from ..config import config
35
36
  from ..datastore import get_stream_pusher
@@ -81,22 +82,28 @@ _task_step_fields = [
81
82
  "responder",
82
83
  "input_path",
83
84
  "result_path",
85
+ "model_endpoint_creation_strategy",
86
+ "endpoint_type",
84
87
  ]
85
88
 
86
89
 
87
90
  MAX_ALLOWED_STEPS = 4500
88
91
 
89
92
 
90
- def new_model_endpoint(class_name, model_path, handler=None, **class_args):
91
- class_args = deepcopy(class_args)
92
- class_args["model_path"] = model_path
93
- return TaskStep(class_name, class_args, handler=handler)
94
-
95
-
96
- def new_remote_endpoint(url, **class_args):
93
def new_remote_endpoint(
    url: str,
    creation_strategy: schemas.ModelEndpointCreationStrategy = schemas.ModelEndpointCreationStrategy.INPLACE,
    endpoint_type: schemas.EndpointType = schemas.EndpointType.NODE_EP,
    **class_args,
):
    """Create a ``$remote`` task step that targets ``url``.

    :param url:               Remote endpoint URL, stored in the step's class args under ``url``.
    :param creation_strategy: Model endpoint creation strategy recorded on the step. Defaults to
                              ``INPLACE`` (the same default ``TaskStep`` uses) so callers of the
                              previous ``new_remote_endpoint(url, **class_args)`` form keep working.
    :param endpoint_type:     Model endpoint type recorded on the step. Defaults to ``NODE_EP``
                              (the same default ``TaskStep`` uses).
    :param class_args:        Additional init arguments for the remote class. They are deep-copied
                              so nested values passed by the caller are not shared with the step.

    :return: A new ``TaskStep`` for the ``$remote`` class.
    """
    class_args = deepcopy(class_args)
    class_args["url"] = url
    return TaskStep(
        "$remote",
        class_args=class_args,
        model_endpoint_creation_strategy=creation_strategy,
        endpoint_type=endpoint_type,
    )
100
107
 
101
108
 
102
109
  class BaseStep(ModelObj):
@@ -419,6 +426,10 @@ class TaskStep(BaseStep):
419
426
  responder: Optional[bool] = None,
420
427
  input_path: Optional[str] = None,
421
428
  result_path: Optional[str] = None,
429
+ model_endpoint_creation_strategy: Optional[
430
+ schemas.ModelEndpointCreationStrategy
431
+ ] = schemas.ModelEndpointCreationStrategy.INPLACE,
432
+ endpoint_type: Optional[schemas.EndpointType] = schemas.EndpointType.NODE_EP,
422
433
  ):
423
434
  super().__init__(name, after)
424
435
  self.class_name = class_name
@@ -438,6 +449,8 @@ class TaskStep(BaseStep):
438
449
  self.on_error = None
439
450
  self._inject_context = False
440
451
  self._call_with_event = False
452
+ self.model_endpoint_creation_strategy = model_endpoint_creation_strategy
453
+ self.endpoint_type = endpoint_type
441
454
 
442
455
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
443
456
  self.context = context
@@ -554,7 +567,11 @@ class TaskStep(BaseStep):
554
567
 
555
568
  def _post_init(self, mode="sync"):
556
569
  if self._object and hasattr(self._object, "post_init"):
557
- self._object.post_init(mode)
570
+ self._object.post_init(
571
+ mode,
572
+ creation_strategy=self.model_endpoint_creation_strategy,
573
+ endpoint_type=self.endpoint_type,
574
+ )
558
575
  if hasattr(self._object, "model_endpoint_uid"):
559
576
  self.endpoint_uid = self._object.model_endpoint_uid
560
577
  if hasattr(self._object, "name"):
@@ -705,6 +722,7 @@ class RouterStep(TaskStep):
705
722
  )
706
723
  self._routes: ObjectDict = None
707
724
  self.routes = routes
725
+ self.endpoint_type = schemas.EndpointType.ROUTER
708
726
 
709
727
  def get_children(self):
710
728
  """get child steps (routes)"""
@@ -726,6 +744,7 @@ class RouterStep(TaskStep):
726
744
  class_name=None,
727
745
  handler=None,
728
746
  function=None,
747
+ creation_strategy: schemas.ModelEndpointCreationStrategy = schemas.ModelEndpointCreationStrategy.INPLACE,
729
748
  **class_args,
730
749
  ):
731
750
  """add child route step or class to the router
@@ -736,12 +755,31 @@ class RouterStep(TaskStep):
736
755
  :param class_args: class init arguments
737
756
  :param handler: class handler to invoke on run/event
738
757
  :param function: function this step should run in
758
+ :param creation_strategy: Strategy for creating or updating the model endpoint:
759
+ * **overwrite**:
760
+ 1. If model endpoints with the same name exist, delete the `latest` one.
761
+ 2. Create a new model endpoint entry and set it as `latest`.
762
+ * **inplace** (default):
763
+ 1. If model endpoints with the same name exist, update the `latest` entry.
764
+ 2. Otherwise, create a new entry.
765
+ * **archive**:
766
+ 1. If model endpoints with the same name exist, preserve them.
767
+ 2. Create a new model endpoint with the same name and set it to `latest`.
768
+
739
769
  """
740
770
 
741
771
  if not route and not class_name and not handler:
742
772
  raise MLRunInvalidArgumentError("route or class_name must be specified")
743
773
  if not route:
744
- route = TaskStep(class_name, class_args, handler=handler)
774
+ route = TaskStep(
775
+ class_name,
776
+ class_args,
777
+ handler=handler,
778
+ model_endpoint_creation_strategy=creation_strategy,
779
+ endpoint_type=schemas.EndpointType.LEAF_EP
780
+ if self.class_name and "serving.VotingEnsemble" in self.class_name
781
+ else schemas.EndpointType.NODE_EP,
782
+ )
745
783
  route.function = function or route.function
746
784
 
747
785
  if len(self._routes) >= MAX_ALLOWED_STEPS:
@@ -805,6 +843,106 @@ class RouterStep(TaskStep):
805
843
  )
806
844
 
807
845
 
846
class Model(storey.ParallelExecutionRunnable):
    """Base class for a model executed by ``ModelRunner`` (see ``ModelRunnerStep.add_model``).

    Subclasses override ``load()`` to prepare the model. They override ``predict()`` for
    synchronous logic, or ``predict_async()`` when the logic requires asyncio.
    """

    def load(self) -> None:
        """Override to load model if needed."""
        pass

    def init(self):
        """Initialize the runnable by delegating to ``load()``."""
        # NOTE(review): presumably invoked by storey before events are processed — confirm.
        self.load()

    def predict(self, body: Any) -> Any:
        """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
        return body

    async def predict_async(self, body: Any) -> Any:
        """Override to implement prediction logic if the logic requires asyncio."""
        return body

    def run(self, body: Any, path: str) -> Any:
        """Synchronous entry point: return ``predict(body)``. ``path`` is not used."""
        return self.predict(body)

    async def run_async(self, body: Any, path: str) -> Any:
        """Asynchronous entry point: return the awaited ``predict_async(body)``. ``path`` is not used."""
        # The async path must call predict_async(). Otherwise models that follow
        # predict()'s docstring and override only predict_async() would never run
        # their asyncio logic.
        return await self.predict_async(body)
867
+
868
+
869
class ModelSelector:
    """Used to select which models to run on each event."""

    def select(
        self, event, available_models: list[Model]
    ) -> Optional[Union[list[str], list[Model]]]:
        """
        Given an event, returns a list of model names or a list of model objects to run on the event.
        If None is returned, all models will be run.

        :param event:            The full event
        :param available_models: List of available models

        :return: The names or the objects of the models to run, or None to run all of them.
                 This default implementation always returns None, so every model runs.
        """
        return None
883
+
884
+
885
class ModelRunner(storey.ParallelExecution):
    """
    Runs multiple Models on each event. See ModelRunnerStep.

    :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
                           event. Optional. If not passed, all models will be run.
    """

    def __init__(self, *args, model_selector: Optional[ModelSelector] = None, **kwargs):
        super().__init__(*args, **kwargs)
        # Without a selector, use the base ModelSelector. Its select() returns None,
        # which by its docstring means "run all models".
        self.model_selector = model_selector if model_selector else ModelSelector()

    def select_runnables(self, event):
        """Return the selector's choice of runnables for ``event``."""
        # The cast only informs type checkers; it performs no runtime conversion.
        available = cast(list[Model], self.runnables)
        chosen = self.model_selector.select(event, available)
        return chosen
900
+
901
+
902
class ModelRunnerStep(TaskStep):
    """
    Runs multiple Models on each event.

    example::

        model_runner_step = ModelRunnerStep(name="my_model_runner")
        model_runner_step.add_model(MyModel(name="my_model"))
        graph.to(model_runner_step)

    :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
                           event, or a class path string that init_object() resolves and instantiates with no
                           arguments. Optional. If not passed, all models will be run.
    """

    kind = "model_runner"

    def __init__(
        self,
        *args,
        model_selector: Optional[Union[str, ModelSelector]] = None,
        **kwargs,
    ):
        # The same list object is stored in class_args["runnables"]. Models added
        # later through add_model() therefore show up there as well.
        models = []
        self._models = models
        step_class_args = {"runnables": models, "model_selector": model_selector}
        super().__init__(
            *args,
            class_name="mlrun.serving.ModelRunner",
            class_args=step_class_args,
            **kwargs,
        )

    def add_model(self, model: Model) -> None:
        """Add a Model to this ModelRunner."""
        self._models.append(model)

    def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
        """Build the ModelRunner from the stored runnables and model selector."""
        # NOTE(review): unlike TaskStep.init_object, this override does not store
        # `context` on the step and ignores `mode`/`reset` — confirm that is intended.
        selector = self.class_args.get("model_selector")
        if isinstance(selector, str):
            # A string selector is a class path: resolve it in `namespace` and instantiate.
            selector_class = get_class(selector, namespace)
            selector = selector_class()
        runnables = self.class_args.get("runnables")
        self._async_object = ModelRunner(runnables, model_selector=selector)
944
+
945
+
808
946
  class QueueStep(BaseStep):
809
947
  """queue step, implement an async queue or represent a stream"""
810
948
 
@@ -1344,8 +1482,9 @@ class FlowStep(BaseStep):
1344
1482
 
1345
1483
  if self._controller:
1346
1484
  if hasattr(self._controller, "terminate"):
1347
- self._controller.terminate()
1348
- return self._controller.await_termination()
1485
+ return self._controller.terminate(wait=True)
1486
+ else:
1487
+ return self._controller.await_termination()
1349
1488
 
1350
1489
  def plot(self, filename=None, format=None, source=None, targets=None, **kw):
1351
1490
  """plot/save graph using graphviz
@@ -1433,6 +1572,7 @@ classes_map = {
1433
1572
  "queue": QueueStep,
1434
1573
  "error_step": ErrorStep,
1435
1574
  "monitoring_application": MonitoringApplicationStep,
1575
+ "model_runner": ModelRunnerStep,
1436
1576
  }
1437
1577
 
1438
1578
 
@@ -1572,6 +1712,10 @@ def params_to_step(
1572
1712
  input_path: Optional[str] = None,
1573
1713
  result_path: Optional[str] = None,
1574
1714
  class_args=None,
1715
+ model_endpoint_creation_strategy: Optional[
1716
+ schemas.ModelEndpointCreationStrategy
1717
+ ] = None,
1718
+ endpoint_type: Optional[schemas.EndpointType] = None,
1575
1719
  ):
1576
1720
  """return step object from provided params or classes/objects"""
1577
1721
 
@@ -1587,6 +1731,9 @@ def params_to_step(
1587
1731
  step.full_event = full_event or step.full_event
1588
1732
  step.input_path = input_path or step.input_path
1589
1733
  step.result_path = result_path or step.result_path
1734
+ if kind == StepKinds.task:
1735
+ step.model_endpoint_creation_strategy = model_endpoint_creation_strategy
1736
+ step.endpoint_type = endpoint_type
1590
1737
 
1591
1738
  elif class_name and class_name in queue_class_names:
1592
1739
  if "path" not in class_args:
@@ -1627,6 +1774,8 @@ def params_to_step(
1627
1774
  full_event=full_event,
1628
1775
  input_path=input_path,
1629
1776
  result_path=result_path,
1777
+ model_endpoint_creation_strategy=model_endpoint_creation_strategy,
1778
+ endpoint_type=endpoint_type,
1630
1779
  )
1631
1780
  else:
1632
1781
  raise MLRunInvalidArgumentError("class_name or handler must be provided")
mlrun/serving/v2_serving.py CHANGED
@@ -23,7 +23,6 @@ import mlrun.common.schemas.model_monitoring
23
23
  import mlrun.model_monitoring
24
24
  from mlrun.utils import logger, now_date
25
25
 
26
- from ..common.schemas.model_monitoring import ModelEndpointSchema
27
26
  from .server import GraphServer
28
27
  from .utils import StepToDict, _extract_input_data, _update_result_body
29
28
 
@@ -130,7 +129,7 @@ class V2ModelServer(StepToDict):
130
129
  self.ready = True
131
130
  self.context.logger.info(f"model {self.name} was loaded")
132
131
 
133
- def post_init(self, mode="sync"):
132
+ def post_init(self, mode="sync", **kwargs):
134
133
  """sync/async model loading, for internal use"""
135
134
  if not self.ready:
136
135
  if mode == "async":
@@ -149,7 +148,10 @@ class V2ModelServer(StepToDict):
149
148
 
150
149
  if not self.context.is_mock or self.context.monitoring_mock:
151
150
  self.model_endpoint_uid = _init_endpoint_record(
152
- graph_server=server, model=self
151
+ graph_server=server,
152
+ model=self,
153
+ creation_strategy=kwargs.get("creation_strategy"),
154
+ endpoint_type=kwargs.get("endpoint_type"),
153
155
  )
154
156
  self._model_logger = (
155
157
  _ModelLogPusher(self, self.context)
@@ -554,7 +556,10 @@ class _ModelLogPusher:
554
556
 
555
557
 
556
558
  def _init_endpoint_record(
557
- graph_server: GraphServer, model: V2ModelServer
559
+ graph_server: GraphServer,
560
+ model: V2ModelServer,
561
+ creation_strategy: mlrun.common.schemas.ModelEndpointCreationStrategy,
562
+ endpoint_type: mlrun.common.schemas.EndpointType,
558
563
  ) -> Union[str, None]:
559
564
  """
560
565
  Initialize model endpoint record and write it into the DB. In general, this method retrieve the unique model
@@ -564,6 +569,17 @@ def _init_endpoint_record(
564
569
  :param graph_server: A GraphServer object which will be used for getting the function uri.
565
570
  :param model: Base model serving class (v2). It contains important details for the model endpoint record
566
571
  such as model name, model path, and model version.
572
+ :param creation_strategy: Strategy for creating or updating the model endpoint:
573
+ * **overwrite**:
574
+ 1. If model endpoints with the same name exist, delete the `latest` one.
575
+ 2. Create a new model endpoint entry and set it as `latest`.
576
+ * **inplace** (default):
577
+ 1. If model endpoints with the same name exist, update the `latest` entry.
578
+ 2. Otherwise, create a new entry.
579
+ * **archive**:
580
+ 1. If model endpoints with the same name exist, preserve them.
581
+ 2. Create a new model endpoint with the same name and set it to `latest`.
582
+ :param endpoint_type model endpoint type
567
583
 
568
584
  :return: Model endpoint unique ID.
569
585
  """
@@ -583,51 +599,30 @@ def _init_endpoint_record(
583
599
  model_uid = None
584
600
  model_tag = None
585
601
  model_labels = {}
586
- try:
587
- model_ep = mlrun.get_run_db().get_model_endpoint(
588
- project=graph_server.project,
589
- name=model.name,
590
- function_name=graph_server.function_name,
591
- function_tag=graph_server.function_tag or "latest",
592
- )
593
- except mlrun.errors.MLRunNotFoundError:
594
- model_ep = None
595
- except mlrun.errors.MLRunBadRequestError as err:
596
- logger.info(
597
- "Cannot get the model endpoints store", err=mlrun.errors.err_to_str(err)
598
- )
599
- return
600
-
601
- function = mlrun.get_run_db().get_function(
602
- name=graph_server.function_name,
602
+ logger.info(
603
+ "Creating Or Updating a new model endpoint record",
604
+ name=model.name,
603
605
  project=graph_server.project,
604
- tag=graph_server.function_tag or "latest",
606
+ function_name=graph_server.function_name,
607
+ function_tag=graph_server.function_tag or "latest",
608
+ model_name=model_name,
609
+ model_tag=model_tag,
610
+ model_db_key=model_db_key,
611
+ model_uid=model_uid,
612
+ model_class=model.__class__.__name__,
613
+ creation_strategy=creation_strategy,
614
+ endpoint_type=endpoint_type,
605
615
  )
606
- function_uid = function.get("metadata", {}).get("uid")
607
- if not model_ep and model.context.server.track_models:
608
- logger.info(
609
- "Creating a new model endpoint record",
610
- name=model.name,
611
- project=graph_server.project,
612
- function_name=graph_server.function_name,
613
- function_tag=graph_server.function_tag or "latest",
614
- function_uid=function_uid,
615
- model_name=model_name,
616
- model_tag=model_tag,
617
- model_db_key=model_db_key,
618
- model_uid=model_uid,
619
- model_class=model.__class__.__name__,
620
- )
616
+ try:
621
617
  model_ep = mlrun.common.schemas.ModelEndpoint(
622
618
  metadata=mlrun.common.schemas.ModelEndpointMetadata(
623
619
  project=graph_server.project,
624
620
  labels=model_labels,
625
621
  name=model.name,
626
- endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.NODE_EP,
622
+ endpoint_type=endpoint_type,
627
623
  ),
628
624
  spec=mlrun.common.schemas.ModelEndpointSpec(
629
625
  function_name=graph_server.function_name,
630
- function_uid=function_uid,
631
626
  function_tag=graph_server.function_tag or "latest",
632
627
  model_name=model_name,
633
628
  model_db_key=model_db_key,
@@ -642,49 +637,11 @@ def _init_endpoint_record(
642
637
  ),
643
638
  )
644
639
  db = mlrun.get_run_db()
645
- model_ep = db.create_model_endpoint(model_endpoint=model_ep)
646
-
647
- elif model_ep:
648
- attributes = {}
649
- if function_uid != model_ep.spec.function_uid:
650
- attributes[ModelEndpointSchema.FUNCTION_UID] = function_uid
651
- if model_name != model_ep.spec.model_name:
652
- attributes[ModelEndpointSchema.MODEL_NAME] = model_name
653
- if model_uid != model_ep.spec.model_uid:
654
- attributes[ModelEndpointSchema.MODEL_UID] = model_uid
655
- if model_tag != model_ep.spec.model_tag:
656
- attributes[ModelEndpointSchema.MODEL_TAG] = model_tag
657
- if model_db_key != model_ep.spec.model_db_key:
658
- attributes[ModelEndpointSchema.MODEL_DB_KEY] = model_db_key
659
- if model_labels != model_ep.metadata.labels:
660
- attributes[ModelEndpointSchema.LABELS] = model_labels
661
- if model.__class__.__name__ != model_ep.spec.model_class:
662
- attributes[ModelEndpointSchema.MODEL_CLASS] = model.__class__.__name__
663
- if (
664
- model_ep.status.monitoring_mode
665
- == mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
666
- ) != model.context.server.track_models:
667
- attributes[ModelEndpointSchema.MONITORING_MODE] = (
668
- mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
669
- if model.context.server.track_models
670
- else mlrun.common.schemas.model_monitoring.ModelMonitoringMode.disabled
671
- )
672
- if attributes:
673
- logger.info(
674
- "Updating model endpoint attributes",
675
- attributes=attributes,
676
- project=model_ep.metadata.project,
677
- name=model_ep.metadata.name,
678
- function_name=model_ep.spec.function_name,
679
- )
680
- db = mlrun.get_run_db()
681
- model_ep = db.patch_model_endpoint(
682
- project=model_ep.metadata.project,
683
- name=model_ep.metadata.name,
684
- endpoint_id=model_ep.metadata.uid,
685
- attributes=attributes,
686
- )
687
- else:
640
+ model_ep = db.create_model_endpoint(
641
+ model_endpoint=model_ep, creation_strategy=creation_strategy
642
+ )
643
+ except mlrun.errors.MLRunBadRequestError as e:
644
+ logger.info("Failed to create model endpoint record", error=e)
688
645
  return None
689
646
 
690
647
  return model_ep.metadata.uid
mlrun/utils/helpers.py CHANGED
@@ -111,15 +111,21 @@ def get_artifact_target(item: dict, project=None):
111
111
  project_str = project or item["metadata"].get("project")
112
112
  tree = item["metadata"].get("tree")
113
113
  tag = item["metadata"].get("tag")
114
+ iter = item["metadata"].get("iter")
114
115
  kind = item.get("kind")
116
+ uid = item["metadata"].get("uid")
115
117
 
116
118
  if kind in {"dataset", "model", "artifact"} and db_key:
117
119
  target = (
118
120
  f"{DB_SCHEMA}://{StorePrefix.kind_to_prefix(kind)}/{project_str}/{db_key}"
119
121
  )
122
+ if iter:
123
+ target = f"{target}#{iter}"
120
124
  target += f":{tag}" if tag else ":latest"
121
125
  if tree:
122
126
  target += f"@{tree}"
127
+ if uid:
128
+ target += f"^{uid}"
123
129
  return target
124
130
 
125
131
  return item["spec"].get("target_path")
mlrun/utils/notifications/notification/base.py CHANGED
@@ -57,7 +57,7 @@ class NotificationBase:
57
57
  typing.Union[mlrun.common.schemas.NotificationSeverity, str]
58
58
  ] = mlrun.common.schemas.NotificationSeverity.INFO,
59
59
  runs: typing.Optional[typing.Union[mlrun.lists.RunList, list]] = None,
60
- custom_html: typing.Optional[typing.Optional[str]] = None,
60
+ custom_html: typing.Optional[str] = None,
61
61
  alert: typing.Optional[mlrun.common.schemas.AlertConfig] = None,
62
62
  event_data: typing.Optional[mlrun.common.schemas.Event] = None,
63
63
  ):
mlrun/utils/notifications/notification/webhook.py CHANGED
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import re
15
16
  import typing
16
17
 
17
18
  import aiohttp
@@ -93,7 +94,6 @@ class WebhookNotification(NotificationBase):
93
94
 
94
95
  @staticmethod
95
96
  def _serialize_runs_in_request_body(override_body, runs):
96
- str_parsed_runs = ""
97
97
  runs = runs or []
98
98
 
99
99
  def parse_runs():
@@ -105,22 +105,23 @@ class WebhookNotification(NotificationBase):
105
105
  parsed_run = {
106
106
  "project": run["metadata"]["project"],
107
107
  "name": run["metadata"]["name"],
108
- "host": run["metadata"]["labels"]["host"],
109
108
  "status": {"state": run["status"]["state"]},
110
109
  }
111
- if run["status"].get("error", None):
112
- parsed_run["status"]["error"] = run["status"]["error"]
113
- elif run["status"].get("results", None):
114
- parsed_run["status"]["results"] = run["status"]["results"]
110
+ if host := run["metadata"].get("labels", {}).get("host", ""):
111
+ parsed_run["host"] = host
112
+ if error := run["status"].get("error"):
113
+ parsed_run["status"]["error"] = error
114
+ elif results := run["status"].get("results"):
115
+ parsed_run["status"]["results"] = results
115
116
  parsed_runs.append(parsed_run)
116
117
  return str(parsed_runs)
117
118
 
118
119
  if isinstance(override_body, dict):
119
120
  for key, value in override_body.items():
120
- if "{{ runs }}" or "{{runs}}" in value:
121
- if not str_parsed_runs:
122
- str_parsed_runs = parse_runs()
123
- override_body[key] = value.replace(
124
- "{{ runs }}", str_parsed_runs
125
- ).replace("{{runs}}", str_parsed_runs)
121
+ if re.search(r"{{\s*runs\s*}}", value):
122
+ str_parsed_runs = parse_runs()
123
+ override_body[key] = re.sub(
124
+ r"{{\s*runs\s*}}", str_parsed_runs, value
125
+ )
126
+
126
127
  return override_body
mlrun/utils/notifications/notification_pusher.py CHANGED
@@ -14,7 +14,6 @@
14
14
 
15
15
  import asyncio
16
16
  import datetime
17
- import os
18
17
  import re
19
18
  import traceback
20
19
  import typing
@@ -97,6 +96,7 @@ class NotificationPusher(_NotificationPusherBase):
97
96
  "completed": "{resource} completed",
98
97
  "error": "{resource} failed",
99
98
  "aborted": "{resource} aborted",
99
+ "running": "{resource} started",
100
100
  }
101
101
 
102
102
  def __init__(
@@ -285,6 +285,7 @@ class NotificationPusher(_NotificationPusherBase):
285
285
 
286
286
  message = (
287
287
  self.messages.get(run.state(), "").format(resource=resource)
288
+ + f" in project {run.metadata.project}"
288
289
  + custom_message
289
290
  )
290
291
 
@@ -303,6 +304,7 @@ class NotificationPusher(_NotificationPusherBase):
303
304
  message, severity, runs = self._prepare_notification_args(
304
305
  run, notification_object
305
306
  )
307
+
306
308
  logger.debug(
307
309
  "Pushing sync notification",
308
310
  notification=sanitize_notification(notification_object.to_dict()),
@@ -313,6 +315,7 @@ class NotificationPusher(_NotificationPusherBase):
313
315
  "project": run.metadata.project,
314
316
  "notification": notification_object,
315
317
  "status": mlrun.common.schemas.NotificationStatus.SENT,
318
+ "run_state": run.state(),
316
319
  }
317
320
  try:
318
321
  notification.push(message, severity, runs)
@@ -351,6 +354,7 @@ class NotificationPusher(_NotificationPusherBase):
351
354
  message, severity, runs = self._prepare_notification_args(
352
355
  run, notification_object
353
356
  )
357
+
354
358
  logger.debug(
355
359
  "Pushing async notification",
356
360
  notification=sanitize_notification(notification_object.to_dict()),
@@ -360,6 +364,7 @@ class NotificationPusher(_NotificationPusherBase):
360
364
  "run_uid": run.metadata.uid,
361
365
  "project": run.metadata.project,
362
366
  "notification": notification_object,
367
+ "run_state": run.state(),
363
368
  "status": mlrun.common.schemas.NotificationStatus.SENT,
364
369
  }
365
370
  try:
@@ -397,10 +402,20 @@ class NotificationPusher(_NotificationPusherBase):
397
402
  run_uid: str,
398
403
  project: str,
399
404
  notification: mlrun.model.Notification,
405
+ run_state: runtimes_constants.RunStates,
400
406
  status: typing.Optional[str] = None,
401
407
  sent_time: typing.Optional[datetime.datetime] = None,
402
408
  reason: typing.Optional[str] = None,
403
409
  ):
410
+ if run_state not in runtimes_constants.RunStates.terminal_states():
411
+ # we want to update the notification status only if the run is in a terminal state for BC
412
+ logger.debug(
413
+ "Skip updating notification status - run not in terminal state",
414
+ run_uid=run_uid,
415
+ state=run_state,
416
+ )
417
+ return
418
+
404
419
  db = mlrun.get_run_db()
405
420
  notification.status = status or notification.status
406
421
  notification.sent_time = sent_time or notification.sent_time
@@ -664,30 +679,10 @@ class CustomNotificationPusher(_NotificationPusherBase):
664
679
  def push_pipeline_start_message(
665
680
  self,
666
681
  project: str,
667
- commit_id: typing.Optional[str] = None,
668
682
  pipeline_id: typing.Optional[str] = None,
669
- has_workflow_url: bool = False,
670
683
  ):
671
- message = f"Workflow started in project {project}"
672
- if pipeline_id:
673
- message += f" id={pipeline_id}"
674
- commit_id = (
675
- commit_id or os.environ.get("GITHUB_SHA") or os.environ.get("CI_COMMIT_SHA")
676
- )
677
- if commit_id:
678
- message += f", commit={commit_id}"
679
- if has_workflow_url:
680
- url = mlrun.utils.helpers.get_workflow_url(project, pipeline_id)
681
- else:
682
- url = mlrun.utils.helpers.get_ui_url(project)
683
- html = ""
684
- if url:
685
- html = (
686
- message
687
- + f'<div><a href="{url}" target="_blank">click here to view progress</a></div>'
688
- )
689
- message = message + f", check progress in {url}"
690
- self.push(message, "info", custom_html=html)
684
+ db = mlrun.get_run_db()
685
+ db.push_run_notifications(pipeline_id, project)
691
686
 
692
687
  def push_pipeline_run_results(
693
688
  self,
mlrun/utils/version/version.json CHANGED
@@ -1,4 +1,4 @@
1
1
  {
2
- "git_commit": "9e2d1e195daed90072d1cd33a5eee339577dc35a",
3
- "version": "1.8.0-rc10"
2
+ "git_commit": "6b23c9dc1a77d33bf56dd6e91623278630bb46c4",
3
+ "version": "1.8.0-rc13"
4
4
  }