mlrun 1.7.0rc17__py3-none-any.whl → 1.7.0rc18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic.

Files changed (55)
  1. mlrun/alerts/alert.py +1 -1
  2. mlrun/artifacts/manager.py +5 -1
  3. mlrun/common/runtimes/constants.py +3 -0
  4. mlrun/common/schemas/__init__.py +1 -1
  5. mlrun/common/schemas/alert.py +31 -9
  6. mlrun/common/schemas/client_spec.py +1 -0
  7. mlrun/common/schemas/function.py +4 -0
  8. mlrun/common/schemas/model_monitoring/__init__.py +3 -1
  9. mlrun/common/schemas/model_monitoring/constants.py +20 -1
  10. mlrun/common/schemas/model_monitoring/grafana.py +9 -5
  11. mlrun/common/schemas/model_monitoring/model_endpoints.py +17 -6
  12. mlrun/config.py +2 -0
  13. mlrun/data_types/to_pandas.py +5 -5
  14. mlrun/datastore/datastore.py +6 -2
  15. mlrun/datastore/redis.py +2 -2
  16. mlrun/datastore/s3.py +5 -0
  17. mlrun/datastore/sources.py +111 -6
  18. mlrun/datastore/targets.py +2 -2
  19. mlrun/db/base.py +5 -1
  20. mlrun/db/httpdb.py +22 -3
  21. mlrun/db/nopdb.py +5 -1
  22. mlrun/errors.py +6 -0
  23. mlrun/feature_store/retrieval/conversion.py +5 -5
  24. mlrun/feature_store/retrieval/job.py +3 -2
  25. mlrun/feature_store/retrieval/spark_merger.py +2 -1
  26. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +2 -2
  27. mlrun/model_monitoring/db/stores/base/store.py +16 -3
  28. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +44 -43
  29. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +190 -91
  30. mlrun/model_monitoring/db/tsdb/__init__.py +35 -6
  31. mlrun/model_monitoring/db/tsdb/base.py +25 -18
  32. mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
  33. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +207 -0
  34. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +45 -0
  35. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +231 -0
  36. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +73 -72
  37. mlrun/model_monitoring/db/v3io_tsdb_reader.py +217 -16
  38. mlrun/model_monitoring/helpers.py +32 -0
  39. mlrun/model_monitoring/stream_processing.py +7 -4
  40. mlrun/model_monitoring/writer.py +18 -13
  41. mlrun/package/utils/_formatter.py +2 -2
  42. mlrun/projects/project.py +33 -8
  43. mlrun/render.py +8 -5
  44. mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
  45. mlrun/utils/async_http.py +25 -5
  46. mlrun/utils/helpers.py +20 -1
  47. mlrun/utils/notifications/notification/slack.py +27 -7
  48. mlrun/utils/notifications/notification_pusher.py +38 -40
  49. mlrun/utils/version/version.json +2 -2
  50. {mlrun-1.7.0rc17.dist-info → mlrun-1.7.0rc18.dist-info}/METADATA +7 -2
  51. {mlrun-1.7.0rc17.dist-info → mlrun-1.7.0rc18.dist-info}/RECORD +55 -51
  52. {mlrun-1.7.0rc17.dist-info → mlrun-1.7.0rc18.dist-info}/LICENSE +0 -0
  53. {mlrun-1.7.0rc17.dist-info → mlrun-1.7.0rc18.dist-info}/WHEEL +0 -0
  54. {mlrun-1.7.0rc17.dist-info → mlrun-1.7.0rc18.dist-info}/entry_points.txt +0 -0
  55. {mlrun-1.7.0rc17.dist-info → mlrun-1.7.0rc18.dist-info}/top_level.txt +0 -0
mlrun/alerts/alert.py CHANGED
@@ -137,7 +137,7 @@ class AlertConfig(ModelObj):
         template = db.get_alert_template(template)
 
         # Extract parameters from the template and apply them to the AlertConfig object
-        self.description = template.description
+        self.summary = template.summary
         self.severity = template.severity
         self.criteria = template.criteria
         self.trigger = template.trigger
mlrun/artifacts/manager.py CHANGED
@@ -72,6 +72,10 @@ class ArtifactProducer:
     def get_meta(self) -> dict:
         return {"kind": self.kind, "name": self.name, "tag": self.tag}
 
+    @property
+    def uid(self):
+        return None
+
 
 def dict_to_artifact(struct: dict) -> Artifact:
     kind = struct.get("kind", "")
@@ -262,7 +266,7 @@ class ArtifactManager:
         if target_path and item.is_dir and not target_path.endswith("/"):
             target_path += "/"
         target_path = template_artifact_path(
-            artifact_path=target_path, project=producer.project
+            artifact_path=target_path, project=producer.project, run_uid=producer.uid
         )
         item.target_path = target_path
 
mlrun/common/runtimes/constants.py CHANGED
@@ -136,6 +136,7 @@ class RunStates:
     unknown = "unknown"
     aborted = "aborted"
     aborting = "aborting"
+    skipped = "skipped"
 
     @staticmethod
     def all():
@@ -148,6 +149,7 @@ class RunStates:
             RunStates.unknown,
             RunStates.aborted,
             RunStates.aborting,
+            RunStates.skipped,
         ]
 
     @staticmethod
@@ -156,6 +158,7 @@ class RunStates:
             RunStates.completed,
             RunStates.error,
             RunStates.aborted,
+            RunStates.skipped,
         ]
 
     @staticmethod
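Note: with the additions above, a skipped pipeline step becomes a known run state and is listed alongside the terminal states (completed, error, aborted). A minimal check using only the names shown in these hunks (requires this release installed):

    from mlrun.common.runtimes.constants import RunStates

    # "skipped" is a plain string constant on RunStates, like the other states
    assert RunStates.skipped == "skipped"
    # and it is now included in the full list of known run states
    assert RunStates.skipped in RunStates.all()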
mlrun/common/schemas/__init__.py CHANGED
@@ -148,10 +148,10 @@ from .model_monitoring import (
     ModelMonitoringMode,
     ModelMonitoringStoreKinds,
     MonitoringFunctionNames,
-    MonitoringTSDBTables,
     PrometheusEndpoints,
     TimeSeriesConnector,
     TSDBTarget,
+    V3IOTSDBTables,
 )
 from .notification import (
     Notification,
mlrun/common/schemas/alert.py CHANGED
@@ -22,7 +22,7 @@ from mlrun.common.types import StrEnum
 
 
 class EventEntityKind(StrEnum):
-    MODEL = "model"
+    MODEL_ENDPOINT_RESULT = "model-endpoint-result"
     JOB = "job"
 
 
@@ -33,14 +33,34 @@ class EventEntities(pydantic.BaseModel):
 
 
 class EventKind(StrEnum):
-    DRIFT_DETECTED = "drift_detected"
-    DRIFT_SUSPECTED = "drift_suspected"
+    DATA_DRIFT_DETECTED = "data_drift_detected"
+    DATA_DRIFT_SUSPECTED = "data_drift_suspected"
+    CONCEPT_DRIFT_DETECTED = "concept_drift_detected"
+    CONCEPT_DRIFT_SUSPECTED = "concept_drift_suspected"
+    MODEL_PERFORMANCE_DETECTED = "model_performance_detected"
+    MODEL_PERFORMANCE_SUSPECTED = "model_performance_suspected"
+    MODEL_SERVING_PERFORMANCE_DETECTED = "model_serving_performance_detected"
+    MODEL_SERVING_PERFORMANCE_SUSPECTED = "model_serving_performance_suspected"
+    MM_APP_ANOMALY_DETECTED = "mm_app_anomaly_detected"
+    MM_APP_ANOMALY_SUSPECTED = "mm_app_anomaly_suspected"
     FAILED = "failed"
 
 
 _event_kind_entity_map = {
-    EventKind.DRIFT_SUSPECTED: [EventEntityKind.MODEL],
-    EventKind.DRIFT_DETECTED: [EventEntityKind.MODEL],
+    EventKind.DATA_DRIFT_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.DATA_DRIFT_DETECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.CONCEPT_DRIFT_DETECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.CONCEPT_DRIFT_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.MODEL_PERFORMANCE_DETECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.MODEL_PERFORMANCE_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.MODEL_SERVING_PERFORMANCE_DETECTED: [
+        EventEntityKind.MODEL_ENDPOINT_RESULT
+    ],
+    EventKind.MODEL_SERVING_PERFORMANCE_SUSPECTED: [
+        EventEntityKind.MODEL_ENDPOINT_RESULT
+    ],
+    EventKind.MM_APP_ANOMALY_DETECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.MM_APP_ANOMALY_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
     EventKind.FAILED: [EventEntityKind.JOB],
 }
 
@@ -123,7 +143,8 @@ class AlertConfig(pydantic.BaseModel):
         pydantic.Field(
             description=(
                 "String to be sent in the notifications generated."
-                "e.g. 'Model {{ $project }}/{{ $entity }} is drifting.'"
+                "e.g. 'Model {{project}}/{{entity}} is drifting.'"
+                "Supported variables: project, entity, name"
             )
         ),
     ]
@@ -161,8 +182,9 @@ class AlertTemplate(
     system_generated: bool = False
 
     # AlertConfig fields that are pre-defined
-    description: Optional[str] = (
-        "String to be sent in the generated notifications e.g. 'Model {{ $project }}/{{ $entity }} is drifting.'"
+    summary: Optional[str] = (
+        "String to be sent in the generated notifications e.g. 'Model {{project}}/{{entity}} is drifting.'"
+        "See AlertConfig.summary description"
     )
     severity: AlertSeverity
     trigger: AlertTrigger
@@ -173,7 +195,7 @@ class AlertTemplate(
     def templates_differ(self, other):
         return (
             self.template_description != other.template_description
-            or self.description != other.description
+            or self.summary != other.summary
            or self.severity != other.severity
            or self.trigger != other.trigger
            or self.reset_policy != other.reset_policy
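Note: the single drift event kind is replaced by per-category kinds, all bound to the new model-endpoint-result entity. A small sketch exercising only the enums added above:

    from mlrun.common.schemas.alert import EventEntityKind, EventKind

    # EventKind.DRIFT_DETECTED / DRIFT_SUSPECTED are gone; callers pick a specific category
    kind = EventKind.DATA_DRIFT_DETECTED
    assert kind.value == "data_drift_detected"

    # model-monitoring events are now attached to a model-endpoint result, not to a "model"
    assert EventEntityKind.MODEL_ENDPOINT_RESULT.value == "model-endpoint-result"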
mlrun/common/schemas/client_spec.py CHANGED
@@ -59,6 +59,7 @@ class ClientSpec(pydantic.BaseModel):
     sql_url: typing.Optional[str]
     model_endpoint_monitoring_store_type: typing.Optional[str]
     model_endpoint_monitoring_endpoint_store_connection: typing.Optional[str]
+    model_monitoring_tsdb_connection: typing.Optional[str]
     ce: typing.Optional[dict]
     # not passing them as one object as it possible client user would like to override only one of the params
     calculate_artifact_hash: typing.Optional[str]
mlrun/common/schemas/function.py CHANGED
@@ -45,6 +45,9 @@ class FunctionState:
     # same goes for the build which is not coming from the pod, but is used and we can't just omit it for BC reasons
     build = "build"
 
+    # for pipeline steps
+    skipped = "skipped"
+
     @classmethod
     def get_function_state_from_pod_state(cls, pod_state: str):
         if pod_state == "succeeded":
@@ -60,6 +63,7 @@ class FunctionState:
         return [
             cls.ready,
             cls.error,
+            cls.skipped,
         ]
 
 
mlrun/common/schemas/model_monitoring/__init__.py CHANGED
@@ -30,20 +30,22 @@ from .constants import (
     ModelMonitoringMode,
     ModelMonitoringStoreKinds,
     MonitoringFunctionNames,
-    MonitoringTSDBTables,
     ProjectSecretKeys,
     PrometheusEndpoints,
     PrometheusMetric,
     ResultData,
     SchedulingKeys,
+    TDEngineSuperTables,
     TimeSeriesConnector,
     TSDBTarget,
+    V3IOTSDBTables,
     VersionedModel,
     WriterEvent,
     WriterEventKind,
 )
 from .grafana import (
     GrafanaColumn,
+    GrafanaColumnType,
     GrafanaDataPoint,
     GrafanaNumberColumn,
     GrafanaStringColumn,
mlrun/common/schemas/model_monitoring/constants.py CHANGED
@@ -81,6 +81,8 @@ class EventFieldType:
     DRIFT_DETECTED_THRESHOLD = "drift_detected_threshold"
     POSSIBLE_DRIFT_THRESHOLD = "possible_drift_threshold"
     SAMPLE_PARQUET_PATH = "sample_parquet_path"
+    TIME = "time"
+    TABLE_COLUMN = "table_column"
 
 
 class FeatureSetFeatures(MonitoringStrEnum):
@@ -171,6 +173,7 @@ class ProjectSecretKeys:
     PIPELINES_ACCESS_KEY = "MODEL_MONITORING_PIPELINES_ACCESS_KEY"
     KAFKA_BROKERS = "KAFKA_BROKERS"
     STREAM_PATH = "STREAM_PATH"
+    TSDB_CONNECTION = "TSDB_CONNECTION"
 
 
 class ModelMonitoringStoreKinds:
@@ -230,12 +233,18 @@ class MonitoringFunctionNames(MonitoringStrEnum):
     WRITER = "model-monitoring-writer"
 
 
-class MonitoringTSDBTables(MonitoringStrEnum):
+class V3IOTSDBTables(MonitoringStrEnum):
     APP_RESULTS = "app-results"
     METRICS = "metrics"
     EVENTS = "events"
 
 
+class TDEngineSuperTables(MonitoringStrEnum):
+    APP_RESULTS = "app_results"
+    METRICS = "metrics"
+    PREDICTIONS = "predictions"
+
+
 @dataclass
 class FunctionURI:
     project: str
@@ -339,6 +348,7 @@ class ControllerPolicy:
 
 class TSDBTarget:
     V3IO_TSDB = "v3io-tsdb"
+    TDEngine = "tdengine"
     PROMETHEUS = "prometheus"
     APP_RESULTS_TABLE = "app-results"
     V3IO_BE = "tsdb"
@@ -348,3 +358,12 @@ class TSDBTarget:
 class HistogramDataDriftApplicationConstants:
     NAME = "histogram-data-drift"
     GENERAL_RESULT_NAME = "general_drift"
+
+
+class PredictionsQueryConstants:
+    DEFAULT_AGGREGATION_GRANULARITY = "10m"
+    INVOCATIONS = "invocations"
+
+
+class SpecialApps:
+    MLRUN_INFRA = "mlrun-infra"
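Note: the TSDB table names are now split per backend: the V3IO tables keep hyphenated names, while the new TDEngine super-tables use underscores. A quick comparison, assuming MonitoringStrEnum iterates like a standard Enum:

    from mlrun.common.schemas.model_monitoring.constants import (
        TDEngineSuperTables,
        V3IOTSDBTables,
    )

    print([table.value for table in V3IOTSDBTables])       # ['app-results', 'metrics', 'events']
    print([table.value for table in TDEngineSuperTables])  # ['app_results', 'metrics', 'predictions']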
mlrun/common/schemas/model_monitoring/grafana.py CHANGED
@@ -11,12 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
 
 from typing import Optional, Union
 
 from pydantic import BaseModel
 
+import mlrun.common.types
+
+
+class GrafanaColumnType(mlrun.common.types.StrEnum):
+    NUMBER = "number"
+    STRING = "string"
+
 
 class GrafanaColumn(BaseModel):
     text: str
@@ -24,13 +30,11 @@ class GrafanaColumn(BaseModel):
 
 
 class GrafanaNumberColumn(GrafanaColumn):
-    text: str
-    type: str = "number"
+    type: str = GrafanaColumnType.NUMBER
 
 
 class GrafanaStringColumn(GrafanaColumn):
-    text: str
-    type: str = "string"
+    type: str = GrafanaColumnType.STRING
 
 
 class GrafanaTable(BaseModel):
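Note: because GrafanaColumnType derives from StrEnum, the refactored defaults still compare and serialize as the plain strings Grafana expects. A minimal check against the models above:

    from mlrun.common.schemas.model_monitoring.grafana import (
        GrafanaNumberColumn,
        GrafanaStringColumn,
    )

    assert GrafanaNumberColumn(text="drift_value").type == "number"
    assert GrafanaStringColumn(text="endpoint_id").type == "string"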
mlrun/common/schemas/model_monitoring/model_endpoints.py CHANGED
@@ -298,6 +298,7 @@ class ModelEndpointList(BaseModel):
 
 class ModelEndpointMonitoringMetricType(mlrun.common.types.StrEnum):
     RESULT = "result"
+    METRIC = "metric"
 
 
 class ModelEndpointMonitoringMetric(BaseModel):
@@ -322,7 +323,7 @@ _FQN_PART_PATTERN = r"[a-zA-Z0-9_-]+"
 _FQN_PATTERN = (
     rf"^(?P<project>{_FQN_PART_PATTERN})\."
     rf"(?P<app>{_FQN_PART_PATTERN})\."
-    rf"(?P<type>{_FQN_PART_PATTERN})\."
+    rf"(?P<type>{ModelEndpointMonitoringMetricType.RESULT}|{ModelEndpointMonitoringMetricType.METRIC})\."
     rf"(?P<name>{_FQN_PART_PATTERN})$"
 )
 _FQN_REGEX = re.compile(_FQN_PATTERN)
@@ -337,27 +338,37 @@ def _parse_metric_fqn_to_monitoring_metric(fqn: str) -> ModelEndpointMonitoringM
     )
 
 
+class _MetricPoint(NamedTuple):
+    timestamp: datetime
+    value: float
+
+
 class _ResultPoint(NamedTuple):
     timestamp: datetime
     value: float
     status: ResultStatusApp
 
 
-class _ModelEndpointMonitoringResultValuesBase(BaseModel):
+class _ModelEndpointMonitoringMetricValuesBase(BaseModel):
     full_name: str
     type: ModelEndpointMonitoringMetricType
     data: bool
 
 
-class ModelEndpointMonitoringResultValues(_ModelEndpointMonitoringResultValuesBase):
-    full_name: str
-    type: ModelEndpointMonitoringMetricType
+class ModelEndpointMonitoringMetricValues(_ModelEndpointMonitoringMetricValuesBase):
+    type: ModelEndpointMonitoringMetricType = ModelEndpointMonitoringMetricType.METRIC
+    values: list[_MetricPoint]
+    data: bool = True
+
+
+class ModelEndpointMonitoringResultValues(_ModelEndpointMonitoringMetricValuesBase):
+    type: ModelEndpointMonitoringMetricType = ModelEndpointMonitoringMetricType.RESULT
     result_kind: ResultKindApp
     values: list[_ResultPoint]
     data: bool = True
 
 
-class ModelEndpointMonitoringResultNoData(_ModelEndpointMonitoringResultValuesBase):
+class ModelEndpointMonitoringMetricNoData(_ModelEndpointMonitoringMetricValuesBase):
     full_name: str
     type: ModelEndpointMonitoringMetricType
     data: bool = False
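Note: the type segment of a metric FQN is now restricted to the two literal values instead of the generic name pattern. A standalone illustration of the same regex, with the enum values inlined:

    import re

    _FQN_PART_PATTERN = r"[a-zA-Z0-9_-]+"
    _FQN_PATTERN = (
        rf"^(?P<project>{_FQN_PART_PATTERN})\."
        rf"(?P<app>{_FQN_PART_PATTERN})\."
        rf"(?P<type>result|metric)\."
        rf"(?P<name>{_FQN_PART_PATTERN})$"
    )

    assert re.match(_FQN_PATTERN, "my-project.my-app.result.general_drift")
    assert re.match(_FQN_PATTERN, "my-project.my-app.metric.latency")
    assert not re.match(_FQN_PATTERN, "my-project.my-app.other.latency")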
mlrun/config.py CHANGED
@@ -521,7 +521,9 @@ default_config = {
         # See mlrun.model_monitoring.db.stores.ObjectStoreFactory for available options
         "store_type": "v3io-nosql",
         "endpoint_store_connection": "",
+        # See mlrun.model_monitoring.db.tsdb.ObjectTSDBFactory for available options
         "tsdb_connector_type": "v3io-tsdb",
+        "tsdb_connection": "",
     },
     "secret_stores": {
         # Use only in testing scenarios (such as integration tests) to avoid using k8s for secrets (will use in-memory
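Note: the new tsdb_connection key is read through mlrun.mlconf like any other config value. A hedged sketch, assuming these keys sit under the model_endpoint_monitoring section as in earlier releases (the connection string is a placeholder, not a documented format):

    import mlrun

    # select the TDEngine connector instead of the default v3io-tsdb ...
    mlrun.mlconf.model_endpoint_monitoring.tsdb_connector_type = "tdengine"
    # ... and point it at the time-series database (placeholder value)
    mlrun.mlconf.model_endpoint_monitoring.tsdb_connection = "taosws://user:password@tdengine-host:6041"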
mlrun/data_types/to_pandas.py CHANGED
@@ -154,10 +154,10 @@ def toPandas(spark_df):
     column_counter = Counter(spark_df.columns)
 
     dtype = [None] * len(spark_df.schema)
-    for fieldIdx, field in enumerate(spark_df.schema):
+    for field_idx, field in enumerate(spark_df.schema):
         # For duplicate column name, we use `iloc` to access it.
         if column_counter[field.name] > 1:
-            pandas_col = pdf.iloc[:, fieldIdx]
+            pandas_col = pdf.iloc[:, field_idx]
         else:
             pandas_col = pdf[field.name]
 
@@ -171,12 +171,12 @@ def toPandas(spark_df):
             and field.nullable
             and pandas_col.isnull().any()
         ):
-            dtype[fieldIdx] = pandas_type
+            dtype[field_idx] = pandas_type
         # Ensure we fall back to nullable numpy types, even when whole column is null:
         if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
-            dtype[fieldIdx] = np.float64
+            dtype[field_idx] = np.float64
         if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
-            dtype[fieldIdx] = object
+            dtype[field_idx] = object
 
     df = pd.DataFrame()
     for index, t in enumerate(dtype):
mlrun/datastore/datastore.py CHANGED
@@ -223,6 +223,11 @@ class StoreManager:
             subpath = url[len("memory://") :]
             return in_memory_store, subpath, url
 
+        elif schema in get_local_file_schema():
+            # parse_url() will drop the windows drive-letter from the path for url like "c:\a\b".
+            # As a workaround, we set subpath to the url.
+            subpath = url.replace("file://", "", 1)
+
         if not schema and endpoint:
             if endpoint in self._stores.keys():
                 return self._stores[endpoint], subpath, url
@@ -241,8 +246,7 @@ class StoreManager:
         )
         if not secrets and not mlrun.config.is_running_as_api():
             self._stores[store_key] = store
-        # in file stores in windows path like c:\a\b the drive letter is dropped from the path, so we return the url
-        return store, url if store.kind == "file" else subpath, url
+        return store, subpath, url
 
     def reset_secrets(self):
         self._secrets = {}
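Note: the workaround keeps the full local path, including a Windows drive letter, in the returned subpath instead of the value parse_url() would produce. The string handling is simple enough to show standalone:

    url = r"file://c:\a\b"
    # strip only the scheme prefix; everything after it, drive letter included, stays intact
    subpath = url.replace("file://", "", 1)
    assert subpath == r"c:\a\b"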
mlrun/datastore/redis.py CHANGED
@@ -31,7 +31,7 @@ class RedisStore(DataStore):
     """
 
     def __init__(self, parent, schema, name, endpoint="", secrets: dict = None):
-        REDIS_DEFAULT_PORT = "6379"
+        redis_default_port = "6379"
         super().__init__(parent, name, schema, endpoint, secrets=secrets)
         self.headers = None
 
@@ -49,7 +49,7 @@ class RedisStore(DataStore):
         user = self._get_secret_or_env("REDIS_USER", "", credentials_prefix)
         password = self._get_secret_or_env("REDIS_PASSWORD", "", credentials_prefix)
         host = parsed_endpoint.hostname
-        port = parsed_endpoint.port if parsed_endpoint.port else REDIS_DEFAULT_PORT
+        port = parsed_endpoint.port if parsed_endpoint.port else redis_default_port
         schema = parsed_endpoint.scheme
         if user or password:
             endpoint = f"{schema}://{user}:{password}@{host}:{port}"
mlrun/datastore/s3.py CHANGED
@@ -198,6 +198,11 @@ class S3Store(DataStore):
         bucket = self.s3.Bucket(bucket)
         return [obj.key[key_length:] for obj in bucket.objects.filter(Prefix=key)]
 
+    def rm(self, path, recursive=False, maxdepth=None):
+        bucket, key = self.get_bucket_and_key(path)
+        path = f"{bucket}/{key}"
+        self.filesystem.rm(path=path, recursive=recursive, maxdepth=maxdepth)
+
 
 def parse_s3_bucket_and_key(s3_path):
     try:
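Note: the new rm() rebuilds the "<bucket>/<key>" path and delegates to the underlying fsspec filesystem. A hedged usage sketch; the store is resolved through mlrun.store_manager as in the sources.py hunks below, and the URL is a placeholder:

    import mlrun

    store, subpath, url = mlrun.store_manager.get_or_create_store("s3://my-bucket/tmp/exports/")
    # remove the whole prefix; maps to the filesystem's rm(recursive=True) under the hood
    store.rm(subpath, recursive=True)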
mlrun/datastore/sources.py CHANGED
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import math
+import operator
 import os
 import warnings
 from base64 import b64encode
@@ -178,7 +180,7 @@ class CSVSource(BaseSourceDriver):
         self,
         name: str = "",
         path: str = None,
-        attributes: dict[str, str] = None,
+        attributes: dict[str, object] = None,
         key_field: str = None,
         schedule: str = None,
         parse_dates: Union[None, int, str, list[int], list[str]] = None,
@@ -305,7 +307,7 @@ class ParquetSource(BaseSourceDriver):
         self,
         name: str = "",
         path: str = None,
-        attributes: dict[str, str] = None,
+        attributes: dict[str, object] = None,
         key_field: str = None,
         time_field: str = None,
         schedule: str = None,
@@ -313,6 +315,10 @@ class ParquetSource(BaseSourceDriver):
         end_time: Optional[Union[datetime, str]] = None,
         additional_filters: Optional[list[tuple]] = None,
     ):
+        if additional_filters:
+            attributes = copy(attributes) or {}
+            attributes["additional_filters"] = additional_filters
+            self.validate_additional_filters(additional_filters)
         super().__init__(
             name,
             path,
@@ -323,7 +329,6 @@ class ParquetSource(BaseSourceDriver):
             start_time,
             end_time,
         )
-        self.additional_filters = additional_filters
 
     @property
     def start_time(self):
@@ -341,6 +346,10 @@ class ParquetSource(BaseSourceDriver):
     def end_time(self, end_time):
         self._end_time = self._convert_to_datetime(end_time)
 
+    @property
+    def additional_filters(self):
+        return self.attributes.get("additional_filters")
+
     @staticmethod
     def _convert_to_datetime(time):
         if time and isinstance(time, str):
@@ -350,6 +359,25 @@ class ParquetSource(BaseSourceDriver):
         else:
             return time
 
+    @staticmethod
+    def validate_additional_filters(additional_filters):
+        if not additional_filters:
+            return
+        for filter_tuple in additional_filters:
+            if not filter_tuple:
+                continue
+            col_name, op, value = filter_tuple
+            if isinstance(value, float) and math.isnan(value):
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "using NaN in additional_filters is not supported"
+                )
+            elif isinstance(value, (list, tuple, set)):
+                for sub_value in value:
+                    if isinstance(sub_value, float) and math.isnan(sub_value):
+                        raise mlrun.errors.MLRunInvalidArgumentError(
+                            "using NaN in additional_filters is not supported"
+                        )
+
     def to_step(
         self,
         key_field=None,
@@ -361,13 +389,12 @@ class ParquetSource(BaseSourceDriver):
     ):
         import storey
 
-        attributes = self.attributes or {}
+        attributes = copy(self.attributes)
+        attributes.pop("additional_filters", None)
         if context:
             attributes["context"] = context
-
         data_item = mlrun.store_manager.object(self.path)
         store, path, url = mlrun.store_manager.get_or_create_store(self.path)
-
         return storey.ParquetSource(
             paths=url,  # unlike self.path, it already has store:// replaced
             key_field=self.key_field or key_field,
@@ -412,6 +439,84 @@ class ParquetSource(BaseSourceDriver):
             **reader_args,
         )
 
+    def _build_spark_additional_filters(self, column_types: dict):
+        if not self.additional_filters:
+            return None
+        from pyspark.sql.functions import col, isnan, lit
+
+        operators = {
+            "==": operator.eq,
+            "=": operator.eq,
+            ">": operator.gt,
+            "<": operator.lt,
+            ">=": operator.ge,
+            "<=": operator.le,
+            "!=": operator.ne,
+        }
+
+        spark_filter = None
+        new_filter = lit(True)
+        for filter_tuple in self.additional_filters:
+            if not filter_tuple:
+                continue
+            col_name, op, value = filter_tuple
+            if op.lower() in ("in", "not in") and isinstance(value, (list, tuple, set)):
+                none_exists = False
+                value = list(value)
+                for sub_value in value:
+                    if sub_value is None:
+                        value.remove(sub_value)
+                        none_exists = True
+                if none_exists:
+                    filter_nan = column_types[col_name] not in ("timestamp", "date")
+                    if value:
+                        if op.lower() == "in":
+                            new_filter = (
+                                col(col_name).isin(value) | col(col_name).isNull()
+                            )
+                            if filter_nan:
+                                new_filter = new_filter | isnan(col(col_name))
+
+                        else:
+                            new_filter = (
+                                ~col(col_name).isin(value) & ~col(col_name).isNull()
+                            )
+                            if filter_nan:
+                                new_filter = new_filter & ~isnan(col(col_name))
+                    else:
+                        if op.lower() == "in":
+                            new_filter = col(col_name).isNull()
+                            if filter_nan:
+                                new_filter = new_filter | isnan(col(col_name))
+                        else:
+                            new_filter = ~col(col_name).isNull()
+                            if filter_nan:
+                                new_filter = new_filter & ~isnan(col(col_name))
+                else:
+                    if op.lower() == "in":
+                        new_filter = col(col_name).isin(value)
+                    elif op.lower() == "not in":
+                        new_filter = ~col(col_name).isin(value)
+            elif op in operators:
+                new_filter = operators[op](col(col_name), value)
+            else:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    f"unsupported filter operator: {op}"
+                )
+            if spark_filter is not None:
+                spark_filter = spark_filter & new_filter
+            else:
+                spark_filter = new_filter
+        return spark_filter
+
+    def _filter_spark_df(self, df, time_field=None, columns=None):
+        spark_additional_filters = self._build_spark_additional_filters(
+            column_types=dict(df.dtypes)
+        )
+        if spark_additional_filters is not None:
+            df = df.filter(spark_additional_filters)
+        return super()._filter_spark_df(df=df, time_field=time_field, columns=columns)
+
 
 class BigQuerySource(BaseSourceDriver):
     """
mlrun/datastore/targets.py CHANGED
@@ -2134,7 +2134,7 @@ class SQLTarget(BaseStoreTarget):
             raise ValueError(f"Table named {table_name} is not exist")
 
         elif not table_exists and create_table:
-            TYPE_TO_SQL_TYPE = {
+            type_to_sql_type = {
                 int: sqlalchemy.Integer,
                 str: sqlalchemy.String(self.attributes.get("varchar_len")),
                 datetime.datetime: sqlalchemy.dialects.mysql.DATETIME(fsp=6),
@@ -2147,7 +2147,7 @@ class SQLTarget(BaseStoreTarget):
             # creat new table with the given name
             columns = []
             for col, col_type in self.schema.items():
-                col_type_sql = TYPE_TO_SQL_TYPE.get(col_type)
+                col_type_sql = type_to_sql_type.get(col_type)
                 if col_type_sql is None:
                     raise TypeError(
                         f"'{col_type}' unsupported type for column '{col}'"
mlrun/db/base.py CHANGED
@@ -17,6 +17,7 @@ from abc import ABC, abstractmethod
 from typing import Optional, Union
 
 import mlrun.alerts
+import mlrun.common.runtimes.constants
 import mlrun.common.schemas
 import mlrun.model_monitoring
 
@@ -63,7 +64,10 @@ class RunDBInterface(ABC):
         uid: Optional[Union[str, list[str]]] = None,
         project: Optional[str] = None,
         labels: Optional[Union[str, list[str]]] = None,
-        state: Optional[str] = None,
+        state: Optional[
+            mlrun.common.runtimes.constants.RunStates
+        ] = None,  # Backward compatibility
+        states: Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
         sort: bool = True,
         last: int = 0,
         iter: bool = False,