mlrun 1.7.0rc12__py3-none-any.whl → 1.7.0rc14__py3-none-any.whl


@@ -49,7 +49,7 @@ class Event(pydantic.BaseModel):
     kind: EventKind
     timestamp: Union[str, datetime] = None  # occurrence time
     entity: EventEntity
-    value: Optional[Union[float, str]] = None
+    value_dict: Optional[dict] = pydantic.Field(default_factory=dict)

     def is_valid(self):
         return self.entity.kind in _event_kind_entity_map[self.kind]
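Note: the alert Event payload moves from a single scalar value to a structured value_dict. A minimal construction sketch (the entity and kind are placeholders; the value_dict keys mirror the ones the model-monitoring writer sends later in this diff):

```python
import mlrun.common.schemas as schemas

# Hypothetical usage - "entity" is an EventEntity built elsewhere and
# "event_kind" is one of the alert EventKind members.
event = schemas.Event(
    kind=event_kind,
    entity=entity,
    value_dict={
        "result_name": "data_drift",
        "result_value": 0.82,
    },
)
assert event.is_valid()  # still validates the kind/entity pairing
```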
@@ -110,6 +110,7 @@ class ProjectSummary(pydantic.BaseModel):
     files_count: int
     feature_sets_count: int
     models_count: int
+    runs_completed_recent_count: int
     runs_failed_recent_count: int
     runs_running_count: int
     schedules_count: int
mlrun/config.py CHANGED
@@ -188,6 +188,7 @@ default_config = {
     "background_tasks": {
         # enabled / disabled
         "timeout_mode": "enabled",
+        "function_deletion_batch_size": 10,
        # timeout in seconds to wait for background task to be updated / finished by the worker responsible for the task
         "default_timeouts": {
             "operations": {
@@ -196,6 +197,7 @@ default_config = {
                "run_abortion": "600",
                "abort_grace_period": "10",
                "delete_project": "900",
+               "delete_function": "900",
            },
            "runtimes": {"dask": "600"},
        },
@@ -552,6 +554,7 @@ default_config = {
         "nosql": "v3io:///projects/{project}/FeatureStore/{name}/nosql",
         # "authority" is optional and generalizes [userinfo "@"] host [":" port]
         "redisnosql": "redis://{authority}/projects/{project}/FeatureStore/{name}/nosql",
+        "dsnosql": "ds://{ds_profile_name}/projects/{project}/FeatureStore/{name}/nosql",
     },
     "default_targets": "parquet,nosql",
     "default_job_image": "mlrun/mlrun",
@@ -691,6 +694,10 @@ default_config = {
         # supported modes: "enabled", "disabled".
         "mode": "disabled"
     },
+    "auth_with_client_id": {
+        "enabled": False,
+        "request_timeout": 5,
+    },
 }

 _is_running_as_api = None
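Note: all of the keys added above can be overridden through the environment. A sketch, assuming mlrun's usual convention of an MLRUN_ prefix with "__" as the nesting separator (values are illustrative):

```python
import os

# Must be set before mlrun loads its configuration.
os.environ["MLRUN_BACKGROUND_TASKS__FUNCTION_DELETION_BATCH_SIZE"] = "20"
os.environ["MLRUN_BACKGROUND_TASKS__DEFAULT_TIMEOUTS__OPERATIONS__DELETE_FUNCTION"] = "600"

import mlrun

print(mlrun.mlconf.background_tasks.function_deletion_batch_size)  # -> 20
```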
@@ -1395,7 +1402,11 @@ def read_env(env=None, prefix=env_prefix):
         log_formatter = mlrun.utils.create_formatter_instance(
             mlrun.utils.FormatterKinds(log_formatter_name)
         )
-        mlrun.utils.logger.get_handler("default").setFormatter(log_formatter)
+        current_handler = mlrun.utils.logger.get_handler("default")
+        current_formatter_name = current_handler.formatter.__class__.__name__
+        desired_formatter_name = log_formatter.__class__.__name__
+        if current_formatter_name != desired_formatter_name:
+            current_handler.setFormatter(log_formatter)

    # The default function pod resource values are of type str; however, when reading from environment variable numbers,
    # it converts them to type int if contains only number, so we want to convert them to str.
@@ -83,14 +83,28 @@ class DatastoreProfileBasic(DatastoreProfile):
 class DatastoreProfileKafkaTarget(DatastoreProfile):
     type: str = pydantic.Field("kafka_target")
     _private_attributes = "kwargs_private"
-    bootstrap_servers: str
-    brokers: str
+    bootstrap_servers: typing.Optional[str] = None
+    brokers: typing.Optional[str] = None
     topic: str
     kwargs_public: typing.Optional[dict]
     kwargs_private: typing.Optional[dict]

-    def __pydantic_post_init__(self):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        if not self.brokers and not self.bootstrap_servers:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "DatastoreProfileKafkaTarget requires the 'brokers' field to be set"
+            )
+
         if self.bootstrap_servers:
+            if self.brokers:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "DatastoreProfileKafkaTarget cannot be created with both 'brokers' and 'bootstrap_servers'"
+                )
+            else:
+                self.brokers = self.bootstrap_servers
+                self.bootstrap_servers = None
             warnings.warn(
                 "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
                 "use 'brokers' instead.",
mlrun/datastore/hdfs.py CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from urllib.parse import urlparse

 import fsspec

@@ -49,3 +50,7 @@ class HdfsStore(DataStore):
     @property
     def spark_url(self):
         return f"hdfs://{self.host}:{self.port}"
+
+    def rm(self, url, recursive=False, maxdepth=None):
+        path = urlparse(url).path
+        self.filesystem.rm(path=path, recursive=recursive, maxdepth=maxdepth)
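Note: the new rm() strips the scheme and authority from the URL and delegates to the underlying fsspec filesystem. A sketch with placeholder paths, where store is an HdfsStore instance:

```python
store.rm("hdfs://namenode:8020/tmp/dataset.parquet")         # single object
store.rm("hdfs://namenode:8020/tmp/output", recursive=True)  # whole subtree
```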
@@ -1390,6 +1390,39 @@ class RedisNoSqlTarget(NoSqlBaseTarget):
     support_spark = True
     writer_step_name = "RedisNoSqlTarget"

+    @property
+    def _target_path_object(self):
+        url = self.path or mlrun.mlconf.redis.url
+        if self._resource and url:
+            parsed_url = urlparse(url)
+            if not parsed_url.path or parsed_url.path == "/":
+                kind_prefix = (
+                    "sets"
+                    if self._resource.kind
+                    == mlrun.common.schemas.ObjectKind.feature_set
+                    else "vectors"
+                )
+                kind = self.kind
+                name = self._resource.metadata.name
+                project = (
+                    self._resource.metadata.project or mlrun.mlconf.default_project
+                )
+                data_prefix = get_default_prefix_for_target(kind).format(
+                    ds_profile_name=parsed_url.netloc,
+                    authority=parsed_url.netloc,
+                    project=project,
+                    kind=kind,
+                    name=name,
+                )
+                if url.startswith("rediss://"):
+                    data_prefix = data_prefix.replace("redis://", "rediss://", 1)
+                if not self.run_id:
+                    version = self._resource.metadata.tag or "latest"
+                    name = f"{name}-{version}"
+                url = f"{data_prefix}/{kind_prefix}/{name}"
+                return TargetPathObject(url, self.run_id, False)
+        return super()._target_path_object
+
    # Fetch server url from the RedisNoSqlTarget::__init__() 'path' parameter.
    # If not set fetch it from 'mlrun.mlconf.redis.url' (MLRUN_REDIS__URL environment variable).
    # Then look for username and password at REDIS_xxx secrets
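Note: the override only applies when the configured URL has no path component; the path is then derived from the redisnosql/dsnosql prefix templates added to config.py above. An illustrative expansion (project, feature set, and tag are placeholders):

```python
# For a feature set "stocks" in project "demo" with tag "latest", a bare URL
#     redis://my-redis:6379
# expands to roughly
#     redis://my-redis:6379/projects/demo/FeatureStore/stocks/nosql/sets/stocks-latest
# A rediss:// URL keeps its TLS scheme; feature vectors use "vectors" instead of "sets".
```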
@@ -1516,6 +1549,25 @@ class StreamTarget(BaseStoreTarget):


 class KafkaTarget(BaseStoreTarget):
+    """
+    Kafka target storage driver, used to write data into kafka topics.
+    example::
+        # define target
+        kafka_target = KafkaTarget(
+            name="kafka", path="my_topic", brokers="localhost:9092"
+        )
+        # ingest
+        stocks_set.ingest(stocks, [kafka_target])
+    :param name: target name
+    :param path: topic name e.g. "my_topic"
+    :param after_step: optional, after what step in the graph to add the target
+    :param columns: optional, which columns from data to write
+    :param bootstrap_servers: Deprecated. Use the brokers parameter instead
+    :param producer_options: additional configurations for kafka producer
+    :param brokers: kafka broker as represented by a host:port pair, or a list of kafka brokers, e.g.
+        "localhost:9092", or ["kafka-broker-1:9092", "kafka-broker-2:9092"]
+    """
+
     kind = TargetTypes.kafka
     is_table = False
     is_online = False
mlrun/datastore/v3io.py CHANGED
@@ -29,7 +29,7 @@ from .base import (
 )

 V3IO_LOCAL_ROOT = "v3io"
-V3IO_DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 1024 * 100
+V3IO_DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 1024 * 10


 class V3ioStore(DataStore):
mlrun/db/auth_utils.py ADDED
@@ -0,0 +1,152 @@
+# Copyright 2024 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
+
+import requests
+
+import mlrun.errors
+from mlrun.utils import logger
+
+
+class TokenProvider(ABC):
+    @abstractmethod
+    def get_token(self):
+        pass
+
+    @abstractmethod
+    def is_iguazio_session(self):
+        pass
+
+
+class StaticTokenProvider(TokenProvider):
+    def __init__(self, token: str):
+        self.token = token
+
+    def get_token(self):
+        return self.token
+
+    def is_iguazio_session(self):
+        return mlrun.platforms.iguazio.is_iguazio_session(self.token)
+
+
+class OAuthClientIDTokenProvider(TokenProvider):
+    def __init__(
+        self, token_endpoint: str, client_id: str, client_secret: str, timeout=5
+    ):
+        if not token_endpoint or not client_id or not client_secret:
+            raise mlrun.errors.MLRunValueError(
+                "Invalid client_id configuration for authentication. Must provide token endpoint, client-id and secret"
+            )
+        self.token_endpoint = token_endpoint
+        self.client_id = client_id
+        self.client_secret = client_secret
+        self.timeout = timeout
+
+        # Since we're only issuing POST requests, which are actually a disguised GET, then it's ok to allow retries
+        # on them.
+        self._session = mlrun.utils.HTTPSessionWithRetry(
+            retry_on_post=True,
+            verbose=True,
+        )
+
+        self._cleanup()
+        self._refresh_token_if_needed()
+
+    def get_token(self):
+        self._refresh_token_if_needed()
+        return self.token
+
+    def is_iguazio_session(self):
+        return False
+
+    def _cleanup(self):
+        self.token = self.token_expiry_time = self.token_refresh_time = None
+
+    def _refresh_token_if_needed(self):
+        now = datetime.now()
+        if self.token:
+            if self.token_refresh_time and now <= self.token_refresh_time:
+                return self.token
+
+            # We only cleanup if token was really expired - even if we fail in refreshing the token, we can still
+            # use the existing one given that it's not expired.
+            if now >= self.token_expiry_time:
+                self._cleanup()
+
+        self._issue_token_request()
+        return self.token
+
+    def _issue_token_request(self, raise_on_error=False):
+        try:
+            headers = {"Content-Type": "application/x-www-form-urlencoded"}
+            request_body = {
+                "grant_type": "client_credentials",
+                "client_id": self.client_id,
+                "client_secret": self.client_secret,
+            }
+            response = self._session.request(
+                "POST",
+                self.token_endpoint,
+                timeout=self.timeout,
+                headers=headers,
+                data=request_body,
+            )
+        except requests.RequestException as exc:
+            error = f"Retrieving token failed: {mlrun.errors.err_to_str(exc)}"
+            if raise_on_error:
+                raise mlrun.errors.MLRunRuntimeError(error) from exc
+            else:
+                logger.warning(error)
+                return
+
+        if not response.ok:
+            error = "No error available"
+            if response.content:
+                try:
+                    data = response.json()
+                    error = data.get("error")
+                except Exception:
+                    pass
+            logger.warning(
+                "Retrieving token failed", status=response.status_code, error=error
+            )
+            if raise_on_error:
+                mlrun.errors.raise_for_status(response)
+            return
+
+        self._parse_response(response.json())
+
+    def _parse_response(self, data: dict):
+        # Response is described in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.3
+        # According to spec, there isn't a refresh token - just the access token and its expiry time (in seconds).
+        self.token = data.get("access_token")
+        expires_in = data.get("expires_in")
+        if not self.token or not expires_in:
+            token_str = "****" if self.token else "missing"
+            logger.warning(
+                "Failed to parse token response", token=token_str, expires_in=expires_in
+            )
+            return
+
+        now = datetime.now()
+        self.token_expiry_time = now + timedelta(seconds=expires_in)
+        self.token_refresh_time = now + timedelta(seconds=expires_in / 2)
+        logger.info(
+            "Successfully retrieved client-id token",
+            expires_in=expires_in,
+            expiry=str(self.token_expiry_time),
+            refresh=str(self.token_refresh_time),
+        )
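Note: a minimal usage sketch for the new provider (endpoint and credentials are placeholders):

```python
from mlrun.db.auth_utils import OAuthClientIDTokenProvider

provider = OAuthClientIDTokenProvider(
    token_endpoint="https://idp.example.com/oauth2/token",
    client_id="my-client",
    client_secret="s3cr3t",
    timeout=5,
)
# The first token is fetched eagerly in __init__; get_token() then refreshes
# transparently once half of expires_in has elapsed.
token = provider.get_token()
```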
mlrun/db/base.py CHANGED
@@ -802,7 +802,7 @@ class RunDBInterface(ABC):
         project: str,
         base_period: int = 10,
         image: str = "mlrun/mlrun",
-    ):
+    ) -> None:
         pass

     @abstractmethod
mlrun/db/httpdb.py CHANGED
@@ -38,6 +38,7 @@ import mlrun.platforms
 import mlrun.projects
 import mlrun.runtimes.nuclio.api_gateway
 import mlrun.utils
+from mlrun.db.auth_utils import OAuthClientIDTokenProvider, StaticTokenProvider
 from mlrun.errors import MLRunInvalidArgumentError, err_to_str

 from ..artifacts import Artifact
@@ -138,17 +139,28 @@ class HTTPRunDB(RunDBInterface):
             endpoint += f":{parsed_url.port}"
         base_url = f"{parsed_url.scheme}://{endpoint}{parsed_url.path}"

+        self.base_url = base_url
         username = parsed_url.username or config.httpdb.user
         password = parsed_url.password or config.httpdb.password
+        self.token_provider = None

-        username, password, token = mlrun.platforms.add_or_refresh_credentials(
-            parsed_url.hostname, username, password, config.httpdb.token
-        )
+        if config.auth_with_client_id.enabled:
+            self.token_provider = OAuthClientIDTokenProvider(
+                token_endpoint=mlrun.get_secret_or_env("MLRUN_AUTH_TOKEN_ENDPOINT"),
+                client_id=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_ID"),
+                client_secret=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_SECRET"),
+                timeout=config.auth_with_client_id.request_timeout,
+            )
+        else:
+            username, password, token = mlrun.platforms.add_or_refresh_credentials(
+                parsed_url.hostname, username, password, config.httpdb.token
+            )
+
+            if token:
+                self.token_provider = StaticTokenProvider(token)

-        self.base_url = base_url
         self.user = username
         self.password = password
-        self.token = token

     def __repr__(self):
         cls = self.__class__.__name__
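Note: enabling the new flow takes the feature flag plus the three MLRUN_AUTH_* values read via mlrun.get_secret_or_env(). A configuration sketch (the MLRUN_AUTH_WITH_CLIENT_ID__ENABLED mapping is an assumption based on mlrun's standard env-to-config convention):

```python
import os

os.environ["MLRUN_AUTH_WITH_CLIENT_ID__ENABLED"] = "true"  # assumed env mapping
os.environ["MLRUN_AUTH_TOKEN_ENDPOINT"] = "https://idp.example.com/oauth2/token"
os.environ["MLRUN_AUTH_CLIENT_ID"] = "my-client"
os.environ["MLRUN_AUTH_CLIENT_SECRET"] = "s3cr3t"

import mlrun

db = mlrun.get_run_db()  # requests now carry "Authorization: Bearer <token>"
```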
@@ -218,17 +230,19 @@ class HTTPRunDB(RunDBInterface):

         if self.user:
             kw["auth"] = (self.user, self.password)
-        elif self.token:
-            # Iguazio auth doesn't support passing token through bearer, so use cookie instead
-            if mlrun.platforms.iguazio.is_iguazio_session(self.token):
-                session_cookie = f'j:{{"sid": "{self.token}"}}'
-                cookies = {
-                    "session": session_cookie,
-                }
-                kw["cookies"] = cookies
-            else:
-                if "Authorization" not in kw.setdefault("headers", {}):
-                    kw["headers"].update({"Authorization": "Bearer " + self.token})
+        elif self.token_provider:
+            token = self.token_provider.get_token()
+            if token:
+                # Iguazio auth doesn't support passing token through bearer, so use cookie instead
+                if self.token_provider.is_iguazio_session():
+                    session_cookie = f'j:{{"sid": "{token}"}}'
+                    cookies = {
+                        "session": session_cookie,
+                    }
+                    kw["cookies"] = cookies
+                else:
+                    if "Authorization" not in kw.setdefault("headers", {}):
+                        kw["headers"].update({"Authorization": "Bearer " + token})

         if mlrun.common.schemas.HeaderNames.client_version not in kw.setdefault(
             "headers", {}
@@ -1142,7 +1156,29 @@ class HTTPRunDB(RunDBInterface):
         project = project or config.default_project
         path = f"projects/{project}/functions/{name}"
         error_message = f"Failed deleting function {project}/{name}"
-        self.api_call("DELETE", path, error_message)
+        response = self.api_call("DELETE", path, error_message, version="v2")
+        if response.status_code == http.HTTPStatus.ACCEPTED:
+            logger.info(
+                "Function is being deleted", project_name=project, function_name=name
+            )
+            background_task = mlrun.common.schemas.BackgroundTask(**response.json())
+            background_task = self._wait_for_background_task_to_reach_terminal_state(
+                background_task.metadata.name, project=project
+            )
+            if (
+                background_task.status.state
+                == mlrun.common.schemas.BackgroundTaskState.succeeded
+            ):
+                logger.info(
+                    "Function deleted", project_name=project, function_name=name
+                )
+            elif (
+                background_task.status.state
+                == mlrun.common.schemas.BackgroundTaskState.failed
+            ):
+                logger.info(
+                    "Function deletion failed", project_name=project, function_name=name
+                )

     def list_functions(self, name=None, project=None, tag=None, labels=None):
         """Retrieve a list of functions, filtered by specific criteria.
@@ -1488,16 +1524,15 @@ class HTTPRunDB(RunDBInterface):
         """

         try:
+            normalized_name = normalize_name(func.metadata.name)
             params = {
-                "name": normalize_name(func.metadata.name),
+                "name": normalized_name,
                 "project": func.metadata.project,
                 "tag": func.metadata.tag,
                 "last_log_timestamp": str(last_log_timestamp),
                 "verbose": bool2str(verbose),
             }
-            _path = (
-                f"projects/{func.metadata.project}/nuclio/{func.metadata.name}/deploy"
-            )
+            _path = f"projects/{func.metadata.project}/nuclio/{normalized_name}/deploy"
             resp = self.api_call("GET", _path, params=params)
         except OSError as err:
             logger.error(f"error getting deploy status: {err_to_str(err)}")
@@ -3214,7 +3249,7 @@ class HTTPRunDB(RunDBInterface):
         project: str,
         base_period: int = 10,
         image: str = "mlrun/mlrun",
-    ):
+    ) -> None:
         """
         Redeploy model monitoring application controller function.

@@ -3224,13 +3259,14 @@ class HTTPRunDB(RunDBInterface):
         :param image: The image of the model monitoring controller function.
                       By default, the image is mlrun/mlrun.
         """
-
-        params = {
-            "image": image,
-            "base_period": base_period,
-        }
-        path = f"projects/{project}/model-monitoring/model-monitoring-controller"
-        self.api_call(method="POST", path=path, params=params)
+        self.api_call(
+            method=mlrun.common.types.HTTPMethod.POST,
+            path=f"projects/{project}/model-monitoring/model-monitoring-controller",
+            params={
+                "base_period": base_period,
+                "image": image,
+            },
+        )

     def enable_model_monitoring(
         self,
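Note: a call-site sketch (argument values are placeholders; per the docstring, image defaults to mlrun/mlrun):

```python
db = mlrun.get_run_db()
db.update_model_monitoring_controller(
    project="my-project",
    base_period=10,
    image="mlrun/mlrun",
)
```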
mlrun/model.py CHANGED
@@ -766,6 +766,11 @@ class RunMetadata(ModelObj):
     def iteration(self, iteration):
         self._iteration = iteration

+    def is_workflow_runner(self):
+        if not self.labels:
+            return False
+        return self.labels.get("job-type", "") == "workflow-runner"
+

 class HyperParamStrategies:
     grid = "grid"
@@ -1218,6 +1223,19 @@ class RunStatus(ModelObj):
         self.reason = reason
         self.notifications = notifications or {}

+    def is_failed(self) -> Optional[bool]:
+        """
+        Returns whether the run has failed.
+        Returns None if the state is not yet defined; the caller is responsible for
+        handling None (e.g. waiting for the state to be defined).
+        """
+        if not self.state:
+            return None
+        return self.state.casefold() in [
+            mlrun.run.RunStatuses.failed.casefold(),
+            mlrun.run.RunStatuses.error.casefold(),
+        ]
+

 class RunTemplate(ModelObj):
     """Run template"""
@@ -253,3 +253,10 @@ def calculate_inputs_statistics(
     )

     return inputs_statistics
+
+
+def get_endpoint_record(project: str, endpoint_id: str):
+    model_endpoint_store = mlrun.model_monitoring.get_store_object(
+        project=project,
+    )
+    return model_endpoint_store.get_model_endpoint(endpoint_id=endpoint_id)
@@ -40,6 +40,7 @@ from mlrun.common.schemas.model_monitoring.constants import (
     ProjectSecretKeys,
     PrometheusEndpoints,
 )
+from mlrun.model_monitoring.helpers import get_endpoint_record
 from mlrun.utils import logger

@@ -1233,13 +1234,6 @@ def update_endpoint_record(
     )


-def get_endpoint_record(project: str, endpoint_id: str):
-    model_endpoint_store = mlrun.model_monitoring.get_store_object(
-        project=project,
-    )
-    return model_endpoint_store.get_model_endpoint(endpoint_id=endpoint_id)
-
-
 def update_monitoring_feature_set(
     endpoint_record: dict[str, typing.Any],
     feature_names: list[str],
@@ -27,8 +27,13 @@ import mlrun.common.schemas.alert as alert_constants
 import mlrun.model_monitoring
 import mlrun.model_monitoring.db.stores
 import mlrun.utils.v3io_clients
-from mlrun.common.schemas.model_monitoring.constants import ResultStatusApp, WriterEvent
+from mlrun.common.schemas.model_monitoring.constants import (
+    EventFieldType,
+    ResultStatusApp,
+    WriterEvent,
+)
 from mlrun.common.schemas.notification import NotificationKind, NotificationSeverity
+from mlrun.model_monitoring.helpers import get_endpoint_record
 from mlrun.serving.utils import StepToDict
 from mlrun.utils import logger
 from mlrun.utils.notifications.notification_pusher import CustomNotificationPusher
@@ -112,6 +117,7 @@ class ModelMonitoringWriter(StepToDict):
             notification_types=[NotificationKind.slack]
         )
         self._create_tsdb_table()
+        self._endpoints_records = {}

     @staticmethod
     def get_v3io_container(project_name: str) -> str:
@@ -174,7 +180,7 @@ class ModelMonitoringWriter(StepToDict):

     @staticmethod
     def _generate_event_on_drift(
-        uid: str, drift_status: str, drift_value: float, project_name: str
+        uid: str, drift_status: str, event_value: dict, project_name: str
     ):
         if (
             drift_status == ResultStatusApp.detected
@@ -191,7 +197,7 @@ class ModelMonitoringWriter(StepToDict):
                 else alert_constants.EventKind.DRIFT_SUSPECTED
             )
             event_data = mlrun.common.schemas.Event(
-                kind=event_kind, entity=entity, value=drift_value
+                kind=event_kind, entity=entity, value_dict=event_value
             )
             mlrun.get_run_db().generate_event(event_kind, event_data)

@@ -227,10 +233,22 @@ class ModelMonitoringWriter(StepToDict):
         _Notifier(event=event, notification_pusher=self._custom_notifier).notify()

         if mlrun.mlconf.alerts.mode == mlrun.common.schemas.alert.AlertsModes.enabled:
+            endpoint_id = event[WriterEvent.ENDPOINT_ID]
+            endpoint_record = self._endpoints_records.setdefault(
+                endpoint_id,
+                get_endpoint_record(project=self.project, endpoint_id=endpoint_id),
+            )
+            event_value = {
+                "app_name": event[WriterEvent.APPLICATION_NAME],
+                "model": endpoint_record.get(EventFieldType.MODEL),
+                "model_endpoint_id": event[WriterEvent.ENDPOINT_ID],
+                "result_name": event[WriterEvent.RESULT_NAME],
+                "result_value": event[WriterEvent.RESULT_VALUE],
+            }
             self._generate_event_on_drift(
                 event[WriterEvent.ENDPOINT_ID],
                 event[WriterEvent.RESULT_STATUS],
-                event[WriterEvent.RESULT_VALUE],
+                event_value,
                 self.project,
             )
         logger.info("Completed event DB writes")