mlrun 1.7.2rc3__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlrun/__init__.py +26 -22
- mlrun/__main__.py +15 -16
- mlrun/alerts/alert.py +150 -15
- mlrun/api/schemas/__init__.py +1 -9
- mlrun/artifacts/__init__.py +2 -3
- mlrun/artifacts/base.py +62 -19
- mlrun/artifacts/dataset.py +17 -17
- mlrun/artifacts/document.py +454 -0
- mlrun/artifacts/manager.py +28 -18
- mlrun/artifacts/model.py +91 -59
- mlrun/artifacts/plots.py +2 -2
- mlrun/common/constants.py +8 -0
- mlrun/common/formatters/__init__.py +1 -0
- mlrun/common/formatters/artifact.py +1 -1
- mlrun/common/formatters/feature_set.py +2 -0
- mlrun/common/formatters/function.py +1 -0
- mlrun/{model_monitoring/db/stores/v3io_kv/__init__.py → common/formatters/model_endpoint.py} +17 -0
- mlrun/common/formatters/pipeline.py +1 -2
- mlrun/common/formatters/project.py +9 -0
- mlrun/common/model_monitoring/__init__.py +0 -5
- mlrun/common/model_monitoring/helpers.py +12 -62
- mlrun/common/runtimes/constants.py +25 -4
- mlrun/common/schemas/__init__.py +9 -5
- mlrun/common/schemas/alert.py +114 -19
- mlrun/common/schemas/api_gateway.py +3 -3
- mlrun/common/schemas/artifact.py +22 -9
- mlrun/common/schemas/auth.py +8 -4
- mlrun/common/schemas/background_task.py +7 -7
- mlrun/common/schemas/client_spec.py +4 -4
- mlrun/common/schemas/clusterization_spec.py +2 -2
- mlrun/common/schemas/common.py +53 -3
- mlrun/common/schemas/constants.py +15 -0
- mlrun/common/schemas/datastore_profile.py +1 -1
- mlrun/common/schemas/feature_store.py +9 -9
- mlrun/common/schemas/frontend_spec.py +4 -4
- mlrun/common/schemas/function.py +10 -10
- mlrun/common/schemas/hub.py +1 -1
- mlrun/common/schemas/k8s.py +3 -3
- mlrun/common/schemas/memory_reports.py +3 -3
- mlrun/common/schemas/model_monitoring/__init__.py +4 -8
- mlrun/common/schemas/model_monitoring/constants.py +127 -46
- mlrun/common/schemas/model_monitoring/grafana.py +18 -12
- mlrun/common/schemas/model_monitoring/model_endpoints.py +154 -160
- mlrun/common/schemas/notification.py +24 -3
- mlrun/common/schemas/object.py +1 -1
- mlrun/common/schemas/pagination.py +4 -4
- mlrun/common/schemas/partition.py +142 -0
- mlrun/common/schemas/pipeline.py +3 -3
- mlrun/common/schemas/project.py +26 -18
- mlrun/common/schemas/runs.py +3 -3
- mlrun/common/schemas/runtime_resource.py +5 -5
- mlrun/common/schemas/schedule.py +1 -1
- mlrun/common/schemas/secret.py +1 -1
- mlrun/{model_monitoring/db/stores/sqldb/__init__.py → common/schemas/serving.py} +10 -1
- mlrun/common/schemas/tag.py +3 -3
- mlrun/common/schemas/workflow.py +6 -5
- mlrun/common/types.py +1 -0
- mlrun/config.py +157 -89
- mlrun/data_types/__init__.py +5 -3
- mlrun/data_types/infer.py +13 -3
- mlrun/data_types/spark.py +2 -1
- mlrun/datastore/__init__.py +59 -18
- mlrun/datastore/alibaba_oss.py +4 -1
- mlrun/datastore/azure_blob.py +4 -1
- mlrun/datastore/base.py +19 -24
- mlrun/datastore/datastore.py +10 -4
- mlrun/datastore/datastore_profile.py +178 -45
- mlrun/datastore/dbfs_store.py +4 -1
- mlrun/datastore/filestore.py +4 -1
- mlrun/datastore/google_cloud_storage.py +4 -1
- mlrun/datastore/hdfs.py +4 -1
- mlrun/datastore/inmem.py +4 -1
- mlrun/datastore/redis.py +4 -1
- mlrun/datastore/s3.py +14 -3
- mlrun/datastore/sources.py +89 -92
- mlrun/datastore/store_resources.py +7 -4
- mlrun/datastore/storeytargets.py +51 -16
- mlrun/datastore/targets.py +38 -31
- mlrun/datastore/utils.py +87 -4
- mlrun/datastore/v3io.py +4 -1
- mlrun/datastore/vectorstore.py +291 -0
- mlrun/datastore/wasbfs/fs.py +13 -12
- mlrun/db/base.py +286 -100
- mlrun/db/httpdb.py +1562 -490
- mlrun/db/nopdb.py +250 -83
- mlrun/errors.py +6 -2
- mlrun/execution.py +194 -50
- mlrun/feature_store/__init__.py +2 -10
- mlrun/feature_store/api.py +20 -458
- mlrun/feature_store/common.py +9 -9
- mlrun/feature_store/feature_set.py +20 -18
- mlrun/feature_store/feature_vector.py +105 -479
- mlrun/feature_store/feature_vector_utils.py +466 -0
- mlrun/feature_store/retrieval/base.py +15 -11
- mlrun/feature_store/retrieval/job.py +2 -1
- mlrun/feature_store/retrieval/storey_merger.py +1 -1
- mlrun/feature_store/steps.py +3 -3
- mlrun/features.py +30 -13
- mlrun/frameworks/__init__.py +1 -2
- mlrun/frameworks/_common/__init__.py +1 -2
- mlrun/frameworks/_common/artifacts_library.py +2 -2
- mlrun/frameworks/_common/mlrun_interface.py +10 -6
- mlrun/frameworks/_common/model_handler.py +31 -31
- mlrun/frameworks/_common/producer.py +3 -1
- mlrun/frameworks/_dl_common/__init__.py +1 -2
- mlrun/frameworks/_dl_common/loggers/__init__.py +1 -2
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +4 -4
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +3 -3
- mlrun/frameworks/_ml_common/__init__.py +1 -2
- mlrun/frameworks/_ml_common/loggers/__init__.py +1 -2
- mlrun/frameworks/_ml_common/model_handler.py +21 -21
- mlrun/frameworks/_ml_common/plans/__init__.py +1 -2
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +3 -1
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/auto_mlrun/__init__.py +1 -2
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +22 -15
- mlrun/frameworks/huggingface/__init__.py +1 -2
- mlrun/frameworks/huggingface/model_server.py +9 -9
- mlrun/frameworks/lgbm/__init__.py +47 -44
- mlrun/frameworks/lgbm/callbacks/__init__.py +1 -2
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -2
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -2
- mlrun/frameworks/lgbm/mlrun_interfaces/__init__.py +1 -2
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +5 -5
- mlrun/frameworks/lgbm/model_handler.py +15 -11
- mlrun/frameworks/lgbm/model_server.py +11 -7
- mlrun/frameworks/lgbm/utils.py +2 -2
- mlrun/frameworks/onnx/__init__.py +1 -2
- mlrun/frameworks/onnx/dataset.py +3 -3
- mlrun/frameworks/onnx/mlrun_interface.py +2 -2
- mlrun/frameworks/onnx/model_handler.py +7 -5
- mlrun/frameworks/onnx/model_server.py +8 -6
- mlrun/frameworks/parallel_coordinates.py +11 -11
- mlrun/frameworks/pytorch/__init__.py +22 -23
- mlrun/frameworks/pytorch/callbacks/__init__.py +1 -2
- mlrun/frameworks/pytorch/callbacks/callback.py +2 -1
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +15 -8
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +19 -12
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +22 -15
- mlrun/frameworks/pytorch/callbacks_handler.py +36 -30
- mlrun/frameworks/pytorch/mlrun_interface.py +17 -17
- mlrun/frameworks/pytorch/model_handler.py +21 -17
- mlrun/frameworks/pytorch/model_server.py +13 -9
- mlrun/frameworks/sklearn/__init__.py +19 -18
- mlrun/frameworks/sklearn/estimator.py +2 -2
- mlrun/frameworks/sklearn/metric.py +3 -3
- mlrun/frameworks/sklearn/metrics_library.py +8 -6
- mlrun/frameworks/sklearn/mlrun_interface.py +3 -2
- mlrun/frameworks/sklearn/model_handler.py +4 -3
- mlrun/frameworks/tf_keras/__init__.py +11 -12
- mlrun/frameworks/tf_keras/callbacks/__init__.py +1 -2
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +17 -14
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +15 -12
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +21 -18
- mlrun/frameworks/tf_keras/model_handler.py +17 -13
- mlrun/frameworks/tf_keras/model_server.py +12 -8
- mlrun/frameworks/xgboost/__init__.py +19 -18
- mlrun/frameworks/xgboost/model_handler.py +13 -9
- mlrun/k8s_utils.py +2 -5
- mlrun/launcher/base.py +3 -4
- mlrun/launcher/client.py +2 -2
- mlrun/launcher/local.py +6 -2
- mlrun/launcher/remote.py +1 -1
- mlrun/lists.py +8 -4
- mlrun/model.py +132 -46
- mlrun/model_monitoring/__init__.py +3 -5
- mlrun/model_monitoring/api.py +113 -98
- mlrun/model_monitoring/applications/__init__.py +0 -5
- mlrun/model_monitoring/applications/_application_steps.py +81 -50
- mlrun/model_monitoring/applications/base.py +467 -14
- mlrun/model_monitoring/applications/context.py +212 -134
- mlrun/model_monitoring/{db/stores/base → applications/evidently}/__init__.py +6 -2
- mlrun/model_monitoring/applications/evidently/base.py +146 -0
- mlrun/model_monitoring/applications/histogram_data_drift.py +89 -56
- mlrun/model_monitoring/applications/results.py +67 -15
- mlrun/model_monitoring/controller.py +701 -315
- mlrun/model_monitoring/db/__init__.py +0 -2
- mlrun/model_monitoring/db/_schedules.py +242 -0
- mlrun/model_monitoring/db/_stats.py +189 -0
- mlrun/model_monitoring/db/tsdb/__init__.py +33 -22
- mlrun/model_monitoring/db/tsdb/base.py +243 -49
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +76 -36
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +33 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +213 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +534 -88
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +1 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +436 -106
- mlrun/model_monitoring/helpers.py +356 -114
- mlrun/model_monitoring/stream_processing.py +190 -345
- mlrun/model_monitoring/tracking_policy.py +11 -4
- mlrun/model_monitoring/writer.py +49 -90
- mlrun/package/__init__.py +3 -6
- mlrun/package/context_handler.py +2 -2
- mlrun/package/packager.py +12 -9
- mlrun/package/packagers/__init__.py +0 -2
- mlrun/package/packagers/default_packager.py +14 -11
- mlrun/package/packagers/numpy_packagers.py +16 -7
- mlrun/package/packagers/pandas_packagers.py +18 -18
- mlrun/package/packagers/python_standard_library_packagers.py +25 -11
- mlrun/package/packagers_manager.py +35 -32
- mlrun/package/utils/__init__.py +0 -3
- mlrun/package/utils/_pickler.py +6 -6
- mlrun/platforms/__init__.py +47 -16
- mlrun/platforms/iguazio.py +4 -1
- mlrun/projects/operations.py +30 -30
- mlrun/projects/pipelines.py +116 -47
- mlrun/projects/project.py +1292 -329
- mlrun/render.py +5 -9
- mlrun/run.py +57 -14
- mlrun/runtimes/__init__.py +1 -3
- mlrun/runtimes/base.py +30 -22
- mlrun/runtimes/daskjob.py +9 -9
- mlrun/runtimes/databricks_job/databricks_runtime.py +6 -5
- mlrun/runtimes/function_reference.py +5 -2
- mlrun/runtimes/generators.py +3 -2
- mlrun/runtimes/kubejob.py +6 -7
- mlrun/runtimes/mounts.py +574 -0
- mlrun/runtimes/mpijob/__init__.py +0 -2
- mlrun/runtimes/mpijob/abstract.py +7 -6
- mlrun/runtimes/nuclio/api_gateway.py +7 -7
- mlrun/runtimes/nuclio/application/application.py +11 -13
- mlrun/runtimes/nuclio/application/reverse_proxy.go +66 -64
- mlrun/runtimes/nuclio/function.py +127 -70
- mlrun/runtimes/nuclio/serving.py +105 -37
- mlrun/runtimes/pod.py +159 -54
- mlrun/runtimes/remotesparkjob.py +3 -2
- mlrun/runtimes/sparkjob/__init__.py +0 -2
- mlrun/runtimes/sparkjob/spark3job.py +22 -12
- mlrun/runtimes/utils.py +7 -6
- mlrun/secrets.py +2 -2
- mlrun/serving/__init__.py +8 -0
- mlrun/serving/merger.py +7 -5
- mlrun/serving/remote.py +35 -22
- mlrun/serving/routers.py +186 -240
- mlrun/serving/server.py +41 -10
- mlrun/serving/states.py +432 -118
- mlrun/serving/utils.py +13 -2
- mlrun/serving/v1_serving.py +3 -2
- mlrun/serving/v2_serving.py +161 -203
- mlrun/track/__init__.py +1 -1
- mlrun/track/tracker.py +2 -2
- mlrun/track/trackers/mlflow_tracker.py +6 -5
- mlrun/utils/async_http.py +35 -22
- mlrun/utils/clones.py +7 -4
- mlrun/utils/helpers.py +511 -58
- mlrun/utils/logger.py +119 -13
- mlrun/utils/notifications/notification/__init__.py +22 -19
- mlrun/utils/notifications/notification/base.py +39 -15
- mlrun/utils/notifications/notification/console.py +6 -6
- mlrun/utils/notifications/notification/git.py +11 -11
- mlrun/utils/notifications/notification/ipython.py +10 -9
- mlrun/utils/notifications/notification/mail.py +176 -0
- mlrun/utils/notifications/notification/slack.py +16 -8
- mlrun/utils/notifications/notification/webhook.py +24 -8
- mlrun/utils/notifications/notification_pusher.py +191 -200
- mlrun/utils/regex.py +12 -2
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/METADATA +81 -54
- mlrun-1.8.0.dist-info/RECORD +351 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/WHEEL +1 -1
- mlrun/model_monitoring/applications/evidently_base.py +0 -137
- mlrun/model_monitoring/db/stores/__init__.py +0 -136
- mlrun/model_monitoring/db/stores/base/store.py +0 -213
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +0 -71
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +0 -190
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +0 -103
- mlrun/model_monitoring/db/stores/sqldb/models/sqlite.py +0 -40
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +0 -659
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +0 -726
- mlrun/model_monitoring/model_endpoint.py +0 -118
- mlrun-1.7.2rc3.dist-info/RECORD +0 -351
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/entry_points.txt +0 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info/licenses}/LICENSE +0 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/top_level.txt +0 -0
mlrun/datastore/targets.py
CHANGED

@@ -40,7 +40,7 @@ from mlrun.utils.helpers import to_parquet
 from mlrun.utils.v3io_clients import get_frames_client
 
 from .. import errors
-from ..data_types import ValueType
+from ..data_types import ValueType, is_spark_dataframe
 from ..platforms.iguazio import parse_path, split_path
 from .datastore_profile import datastore_profile_read
 from .spark_utils import spark_session_update_hadoop_options

@@ -86,8 +86,10 @@ def generate_target_run_id():
 
 
 def write_spark_dataframe_with_options(spark_options, df, mode, write_format=None):
+    # TODO: Replace with just df.sparkSession when Spark 3.2 support is dropped
+    spark_session = getattr(df, "sparkSession", None) or df.sql_ctx.sparkSession
     non_hadoop_spark_options = spark_session_update_hadoop_options(
-
+        spark_session, spark_options
     )
     if write_format:
         df.write.format(write_format).mode(mode).save(**non_hadoop_spark_options)
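The fallback in the new helper is a Spark-version shim: DataFrame.sparkSession became public API only in Spark 3.3, while Spark 3.2 still requires going through the deprecated sql_ctx. A minimal standalone sketch of the same lookup (the function name here is illustrative, not MLRun API):

    def resolve_spark_session(df):
        # Spark >= 3.3 exposes df.sparkSession directly; on Spark 3.2
        # the session is only reachable via the deprecated df.sql_ctx.
        return getattr(df, "sparkSession", None) or df.sql_ctx.sparkSession

Once Spark 3.2 support is dropped, per the TODO, the whole helper collapses to df.sparkSession.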
@@ -396,7 +398,7 @@ class BaseStoreTarget(DataTargetBase):
         self,
         name: str = "",
         path=None,
-        attributes: dict[str, str] = None,
+        attributes: Optional[dict[str, str]] = None,
         after_step=None,
         columns=None,
         partitioned: bool = False,

@@ -405,8 +407,8 @@ class BaseStoreTarget(DataTargetBase):
         time_partitioning_granularity: Optional[str] = None,
         max_events: Optional[int] = None,
         flush_after_seconds: Optional[int] = None,
-        storage_options: dict[str, str] = None,
-        schema: dict[str, Any] = None,
+        storage_options: Optional[dict[str, str]] = None,
+        schema: Optional[dict[str, Any]] = None,
         credentials_prefix=None,
     ):
         super().__init__(

@@ -441,8 +443,8 @@ class BaseStoreTarget(DataTargetBase):
         self.credentials_prefix = credentials_prefix
         if credentials_prefix:
             warnings.warn(
-                "The 'credentials_prefix' parameter is deprecated and will be removed in "
-                "1.
+                "The 'credentials_prefix' parameter is deprecated in 1.7.0 and will be removed in "
+                "1.10.0. Please use datastore profiles instead.",
                 FutureWarning,
             )
 
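The signature changes in these hunks, and the matching ones in utils.py, v3io.py, and wasbfs/fs.py below, are one mechanical fix applied everywhere: parameters defaulting to None gain an explicit Optional[...]. PEP 484 no longer blesses the implicit-Optional shorthand, and strict type checkers reject it (mypy flags it by default since 0.990). A quick before/after:

    from typing import Optional

    # Before: annotated dict[str, str] but defaulting to None, an implicit
    # Optional that strict type checkers now flag.
    def configure(attributes: dict[str, str] = None): ...

    # After: the None default is spelled out in the annotation.
    def configure_fixed(attributes: Optional[dict[str, str]] = None): ...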
@@ -510,7 +512,7 @@ class BaseStoreTarget(DataTargetBase):
         chunk_id=0,
         **kwargs,
     ) -> Optional[int]:
-        if
+        if is_spark_dataframe(df):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df, key_column, timestamp_key, options)
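The removed condition is truncated in this rendering, but the replacement is clear: the Spark check is now centralized in is_spark_dataframe, imported from ..data_types (the mlrun/data_types/__init__.py entry in the file list reflects the new export). The diff does not show its implementation; a plausible duck-typed sketch, written so pyspark is never imported just to perform the check, might be:

    def is_spark_dataframe(df) -> bool:
        # Hypothetical stand-in for mlrun.data_types.is_spark_dataframe;
        # the real helper may differ. Checking the module path identifies
        # a pyspark DataFrame without requiring pyspark to be installed.
        return type(df).__module__.startswith("pyspark.sql")

The same guard replaces truncated if conditions in NoSqlBaseTarget.write_dataframe and SQLTarget below.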
@@ -834,16 +836,16 @@ class ParquetTarget(BaseStoreTarget):
         self,
         name: str = "",
         path=None,
-        attributes: dict[str, str] = None,
+        attributes: Optional[dict[str, str]] = None,
         after_step=None,
         columns=None,
-        partitioned: bool = None,
+        partitioned: Optional[bool] = None,
         key_bucketing_number: Optional[int] = None,
         partition_cols: Optional[list[str]] = None,
         time_partitioning_granularity: Optional[str] = None,
         max_events: Optional[int] = 10000,
         flush_after_seconds: Optional[int] = 900,
-        storage_options: dict[str, str] = None,
+        storage_options: Optional[dict[str, str]] = None,
     ):
         self.path = path
         if partitioned is None:

@@ -1200,7 +1202,7 @@ class SnowflakeTarget(BaseStoreTarget):
         self,
         name: str = "",
         path=None,
-        attributes: dict[str, str] = None,
+        attributes: Optional[dict[str, str]] = None,
         after_step=None,
         columns=None,
         partitioned: bool = False,

@@ -1209,15 +1211,15 @@ class SnowflakeTarget(BaseStoreTarget):
         time_partitioning_granularity: Optional[str] = None,
         max_events: Optional[int] = None,
         flush_after_seconds: Optional[int] = None,
-        storage_options: dict[str, str] = None,
-        schema: dict[str, Any] = None,
+        storage_options: Optional[dict[str, str]] = None,
+        schema: Optional[dict[str, Any]] = None,
         credentials_prefix=None,
-        url: str = None,
-        user: str = None,
-        db_schema: str = None,
-        database: str = None,
-        warehouse: str = None,
-        table_name: str = None,
+        url: Optional[str] = None,
+        user: Optional[str] = None,
+        db_schema: Optional[str] = None,
+        database: Optional[str] = None,
+        warehouse: Optional[str] = None,
+        table_name: Optional[str] = None,
     ):
         attributes = attributes or {}
         if url:

@@ -1376,7 +1378,7 @@ class NoSqlBaseTarget(BaseStoreTarget):
     def write_dataframe(
         self, df, key_column=None, timestamp_key=None, chunk_id=0, **kwargs
     ):
-        if
+        if is_spark_dataframe(df):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df)

@@ -1669,7 +1671,7 @@ class KafkaTarget(BaseStoreTarget):
     ):
         attrs = {}
 
-        # TODO: Remove this in 1.
+        # TODO: Remove this in 1.10.0
         if bootstrap_servers:
             if brokers:
                 raise mlrun.errors.MLRunInvalidArgumentError(

@@ -1677,7 +1679,7 @@ class KafkaTarget(BaseStoreTarget):
                     "'bootstrap_servers' parameter. Please use 'brokers' only."
                 )
             warnings.warn(
-                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.
+                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.10.0, "
                 "use 'brokers' instead.",
                 FutureWarning,
             )
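In user code the rename plays out as below. This is a sketch: the constructor takes more arguments than the two shown in the hunks, and the kafka:// path shape is assumed for illustration.

    from mlrun.datastore.targets import KafkaTarget

    # Deprecated spelling, emits the FutureWarning corrected above:
    KafkaTarget(path="kafka://broker-1:9092/my-topic",
                bootstrap_servers="broker-1:9092")

    # Preferred spelling until 'bootstrap_servers' is removed in 1.10.0:
    KafkaTarget(path="kafka://broker-1:9092/my-topic",
                brokers="broker-1:9092")

    # Passing both raises MLRunInvalidArgumentError ("... use 'brokers' only.")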
@@ -1708,6 +1710,11 @@ class KafkaTarget(BaseStoreTarget):
         if not path:
             raise mlrun.errors.MLRunInvalidArgumentError("KafkaTarget requires a path")
 
+        # Filter attributes to keep only Kafka-related parameters
+        # This removes any non-Kafka parameters inherited from BaseStoreTarget
+        attributes = mlrun.datastore.utils.KafkaParameters().valid_entries_only(
+            self.attributes
+        )
         graph.add_step(
             name=self.name or "KafkaTarget",
             after=after,

@@ -1715,7 +1722,7 @@ class KafkaTarget(BaseStoreTarget):
             class_name="mlrun.datastore.storeytargets.KafkaStoreyTarget",
             columns=column_list,
             path=path,
-            attributes=
+            attributes=attributes,
         )
 
     def purge(self):
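The filtering delegates to the new KafkaParameters helper in mlrun/datastore/utils.py (reconstructed below): only keys that appear in kafka-python's consumer, producer, or admin DEFAULT_CONFIG, or in MLRun's own custom attribute set (brokers, topics, sasl, ...), survive. Roughly (requires the kafka-python package; "time_field" stands in for any non-Kafka attribute a target may carry):

    from mlrun.datastore.utils import KafkaParameters

    attrs = {
        "max_request_size": 2097152,  # kafka-python producer option -> kept
        "brokers": "broker-1:9092",   # MLRun custom attribute -> kept
        "time_field": "ts",           # not a Kafka option -> dropped
    }
    print(KafkaParameters().valid_entries_only(attrs))
    # {'max_request_size': 2097152, 'brokers': 'broker-1:9092'}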
@@ -1904,7 +1911,7 @@ class SQLTarget(BaseStoreTarget):
         self,
         name: str = "",
         path=None,
-        attributes: dict[str, str] = None,
+        attributes: Optional[dict[str, str]] = None,
         after_step=None,
         partitioned: bool = False,
         key_bucketing_number: Optional[int] = None,

@@ -1912,16 +1919,16 @@ class SQLTarget(BaseStoreTarget):
         time_partitioning_granularity: Optional[str] = None,
         max_events: Optional[int] = None,
         flush_after_seconds: Optional[int] = None,
-        storage_options: dict[str, str] = None,
-        db_url: str = None,
-        table_name: str = None,
-        schema: dict[str, Any] = None,
+        storage_options: Optional[dict[str, str]] = None,
+        db_url: Optional[str] = None,
+        table_name: Optional[str] = None,
+        schema: Optional[dict[str, Any]] = None,
         primary_key_column: str = "",
         if_exists: str = "append",
         create_table: bool = False,
         # create_according_to_data: bool = False,
         varchar_len: int = 50,
-        parse_dates: list[str] = None,
+        parse_dates: Optional[list[str]] = None,
     ):
         """
         Write to SqlDB as output target for a flow.

@@ -2103,7 +2110,7 @@ class SQLTarget(BaseStoreTarget):
 
         self._create_sql_table()
 
-        if
+        if is_spark_dataframe(df):
             raise ValueError("Spark is not supported")
         else:
             (
mlrun/datastore/utils.py
CHANGED

@@ -26,7 +26,7 @@ import mlrun.datastore
 
 
 def parse_kafka_url(
-    url: str, brokers: typing.Union[list, str] = None
+    url: str, brokers: typing.Optional[typing.Union[list, str]] = None
 ) -> tuple[str, list]:
     """Generating Kafka topic and adjusting a list of bootstrap servers.
 

@@ -71,7 +71,7 @@ def upload_tarball(source_dir, target, secrets=None):
 
 def filter_df_start_end_time(
     df: typing.Union[pd.DataFrame, typing.Iterator[pd.DataFrame]],
-    time_column: str = None,
+    time_column: typing.Optional[str] = None,
     start_time: pd.Timestamp = None,
     end_time: pd.Timestamp = None,
 ) -> typing.Union[pd.DataFrame, typing.Iterator[pd.DataFrame]]:

@@ -176,8 +176,8 @@ def get_kafka_brokers_from_dict(options: dict, pop=False) -> typing.Optional[str
     kafka_bootstrap_servers = get_or_pop("kafka_bootstrap_servers", None)
     if kafka_bootstrap_servers:
         warnings.warn(
-            "The 'kafka_bootstrap_servers' parameter is deprecated and will be removed in "
-            "1.
+            "The 'kafka_bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in "
+            "1.10.0. Please pass the 'kafka_brokers' parameter instead.",
             FutureWarning,
         )
         return kafka_bootstrap_servers
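For context on the first hunk: per its docstring and signature, parse_kafka_url splits a kafka:// URL into a topic plus a broker list, with the brokers argument (string or list) filling in when the URL carries no host. Expected behavior, inferred from the signature rather than shown in this diff:

    from mlrun.datastore.utils import parse_kafka_url

    topic, brokers = parse_kafka_url("kafka://broker-1:9092/my-topic")
    # expected: ("my-topic", ["broker-1:9092"])

    topic, brokers = parse_kafka_url("kafka:///my-topic", brokers="broker-1:9092")
    # expected: ("my-topic", ["broker-1:9092"])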
@@ -222,3 +222,86 @@ def validate_additional_filters(additional_filters):
             for sub_value in value:
                 if isinstance(sub_value, float) and math.isnan(sub_value):
                     raise mlrun.errors.MLRunInvalidArgumentError(nan_error_message)
+
+
+class KafkaParameters:
+    def __init__(self, kwargs: typing.Optional[dict] = None):
+        import kafka
+
+        if kwargs is None:
+            kwargs = {}
+        self._kafka = kafka
+        self._kwargs = kwargs
+        self._client_configs = {
+            "consumer": self._kafka.KafkaConsumer.DEFAULT_CONFIG,
+            "producer": self._kafka.KafkaProducer.DEFAULT_CONFIG,
+            "admin": self._kafka.KafkaAdminClient.DEFAULT_CONFIG,
+        }
+        self._custom_attributes = {
+            "max_workers": "",
+            "brokers": "",
+            "topics": "",
+            "group": "",
+            "initial_offset": "",
+            "partitions": "",
+            "sasl": "",
+            "worker_allocation_mode": "",
+        }
+        self._reference_dicts = (
+            self._custom_attributes,
+            self._kafka.KafkaAdminClient.DEFAULT_CONFIG,
+            self._kafka.KafkaProducer.DEFAULT_CONFIG,
+            self._kafka.KafkaConsumer.DEFAULT_CONFIG,
+        )
+
+        self._validate_keys()
+
+    def _validate_keys(self) -> None:
+        for key in self._kwargs:
+            if all(key not in d for d in self._reference_dicts):
+                raise ValueError(
+                    f"Key '{key}' not found in any of the Kafka reference dictionaries"
+                )
+
+    def _get_config(self, client_type: str) -> dict:
+        res = {
+            k: self._kwargs[k]
+            for k in self._kwargs.keys() & self._client_configs[client_type].keys()
+        }
+        if sasl := self._kwargs.get("sasl"):
+            res |= {
+                "security_protocol": "SASL_PLAINTEXT",
+                "sasl_mechanism": sasl["mechanism"],
+                "sasl_plain_username": sasl["user"],
+                "sasl_plain_password": sasl["password"],
+            }
+        return res
+
+    def consumer(self) -> dict:
+        return self._get_config("consumer")
+
+    def producer(self) -> dict:
+        return self._get_config("producer")
+
+    def admin(self) -> dict:
+        return self._get_config("admin")
+
+    def sasl(
+        self, *, usr: typing.Optional[str] = None, pwd: typing.Optional[str] = None
+    ) -> dict:
+        usr = usr or self._kwargs.get("sasl_plain_username", None)
+        pwd = pwd or self._kwargs.get("sasl_plain_password", None)
+        res = self._kwargs.get("sasl", {})
+        if usr and pwd:
+            res["enable"] = True
+            res["user"] = usr
+            res["password"] = pwd
+            res["mechanism"] = self._kwargs.get("sasl_mechanism", "PLAIN")
+        return res
+
+    def valid_entries_only(self, input_dict: dict) -> dict:
+        valid_keys = set()
+        for ref_dict in self._reference_dicts:
+            valid_keys.update(ref_dict.keys())
+        # Return a new dictionary with only valid keys
+        return {k: v for k, v in input_dict.items() if k in valid_keys}
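KafkaParameters validates an attribute dict against kafka-python's DEFAULT_CONFIG tables plus MLRun's custom keys at construction time, then slices out per-client configs on demand. A short sketch (needs the kafka-python package installed, since the class imports kafka and reads KafkaConsumer.DEFAULT_CONFIG and friends):

    from mlrun.datastore.utils import KafkaParameters

    params = KafkaParameters(
        {
            "bootstrap_servers": ["broker-1:9092"],  # known to all three clients
            "group_id": "monitoring",                # consumer-only option
            "sasl": {"mechanism": "PLAIN", "user": "me", "password": "s3cr3t"},
        }
    )
    params.consumer()  # bootstrap_servers + group_id + SASL_PLAINTEXT settings
    params.producer()  # bootstrap_servers only, plus the same SASL settings
    KafkaParameters({"no_such_option": 1})  # raises ValueError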
mlrun/datastore/v3io.py
CHANGED

@@ -14,6 +14,7 @@
 
 import time
 from datetime import datetime
+from typing import Optional
 
 import fsspec
 import v3io

@@ -33,7 +34,9 @@ V3IO_DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 1024 * 10
 
 
 class V3ioStore(DataStore):
-    def __init__(
+    def __init__(
+        self, parent, schema, name, endpoint="", secrets: Optional[dict] = None
+    ):
         super().__init__(parent, name, schema, endpoint, secrets=secrets)
         self.endpoint = self.endpoint or mlrun.mlconf.v3io_api
 
mlrun/datastore/vectorstore.py
ADDED

@@ -0,0 +1,291 @@
+# Copyright 2024 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from collections.abc import Iterable
+from typing import Optional, Union
+
+from mlrun.artifacts import DocumentArtifact
+
+
+def find_existing_attribute(obj, base_name="name", parent_name="collection"):
+    # Define all possible patterns
+
+    return None
+
+
+def _extract_collection_name(vectorstore: "VectorStore") -> str:  # noqa: F821
+    patterns = [
+        "collection.name",
+        "collection._name",
+        "_collection.name",
+        "_collection._name",
+        "collection_name",
+        "_collection_name",
+    ]
+
+    def resolve_attribute(obj, pattern):
+        if "." in pattern:
+            parts = pattern.split(".")
+            current = vectorstore
+            for part in parts:
+                if hasattr(current, part):
+                    current = getattr(current, part)
+                else:
+                    return None
+            return current
+        else:
+            return getattr(obj, pattern, None)
+
+    if type(vectorstore).__name__ == "PineconeVectorStore":
+        try:
+            url = (
+                vectorstore._index.config.host
+                if hasattr(vectorstore._index, "config")
+                else vectorstore._index._config.host
+            )
+            index_name = url.split("//")[1].split("-")[0]
+            return index_name
+        except Exception:
+            pass
+
+    for pattern in patterns:
+        try:
+            value = resolve_attribute(vectorstore, pattern)
+            if value is not None:
+                return value
+        except (AttributeError, TypeError):
+            continue
+
+    # If we get here, we couldn't find a valid collection name
+    raise ValueError(
+        "Failed to extract collection name from the vector store. "
+        "Please provide the collection name explicitly. "
+    )
+
+
+class VectorStoreCollection:
+    """
+    A wrapper class for vector store collections with MLRun integration.
+
+    This class wraps a vector store implementation (like Milvus, Chroma) and provides
+    integration with MLRun context for document and artifact management. It delegates
+    most operations to the underlying vector store while handling MLRun-specific
+    functionality.
+
+    The class implements attribute delegation through __getattr__ and __setattr__,
+    allowing direct access to the underlying vector store's methods and attributes
+    while maintaining MLRun integration.
+    """
+
+    def __init__(
+        self,
+        mlrun_context: Union["MlrunProject", "MLClientCtx"],  # noqa: F821
+        vector_store: "VectorStore",  # noqa: F821
+        collection_name: Optional[str] = None,
+    ):
+        self._collection_impl = vector_store
+        self._mlrun_context = mlrun_context
+        self.collection_name = collection_name or _extract_collection_name(vector_store)
+
+    @property
+    def __class__(self):
+        # Make isinstance() check the wrapped object's class
+        return self._collection_impl.__class__
+
+    def __getattr__(self, name):
+        # This method is called when an attribute is not found in the usual places
+        # Forward the attribute access to _collection_impl
+        return getattr(self._collection_impl, name)
+
+    def __setattr__(self, name, value):
+        if name in ["_collection_impl", "_mlrun_context"] or name in self.__dict__:
+            # Use the base class method to avoid recursion
+            super().__setattr__(name, value)
+        else:
+            # Forward the attribute setting to _collection_impl
+            setattr(self._collection_impl, name, value)
+
+    def _get_mlrun_project_name(self):
+        import mlrun
+
+        if self._mlrun_context and isinstance(
+            self._mlrun_context, mlrun.projects.MlrunProject
+        ):
+            return self._mlrun_context.name
+        if self._mlrun_context and isinstance(
+            self._mlrun_context, mlrun.execution.MLClientCtx
+        ):
+            return self._mlrun_context.get_project_object().name
+        return None
+
+    def delete(self, *args, **kwargs):
+        self._collection_impl.delete(*args, **kwargs)
+
+    def add_documents(
+        self,
+        documents: list["Document"],  # noqa: F821
+        **kwargs,
+    ):
+        """
+        Add a list of documents to the collection.
+
+        If the instance has an MLRun context, it will update the MLRun artifacts
+        associated with the documents.
+
+        Args:
+            documents (list[Document]): A list of Document objects to be added.
+            **kwargs: Additional keyword arguments to be passed to the underlying
+                collection implementation.
+
+        Returns:
+            The result of the underlying collection implementation's add_documents method.
+        """
+        if self._mlrun_context:
+            for document in documents:
+                mlrun_key = document.metadata.get(
+                    DocumentArtifact.METADATA_ARTIFACT_KEY, None
+                )
+                mlrun_project = document.metadata.get(
+                    DocumentArtifact.METADATA_ARTIFACT_PROJECT, None
+                )
+
+                if mlrun_key and mlrun_project == self._get_mlrun_project_name():
+                    mlrun_tag = document.metadata.get(
+                        DocumentArtifact.METADATA_ARTIFACT_TAG, None
+                    )
+                    artifact = self._mlrun_context.get_artifact(
+                        key=mlrun_key, tag=mlrun_tag
+                    )
+                    if artifact.collection_add(self.collection_name):
+                        self._mlrun_context.update_artifact(artifact)
+
+        return self._collection_impl.add_documents(documents, **kwargs)
+
+    def add_artifacts(self, artifacts: list[DocumentArtifact], splitter=None, **kwargs):
+        """
+        Add a list of DocumentArtifact objects to the vector store collection.
+
+        Converts artifacts to LangChain documents, adds them to the vector store, and
+        updates the MLRun context. If documents are split, the IDs are handled appropriately.
+
+        :param artifacts: List of DocumentArtifact objects to add
+        :type artifacts: list[DocumentArtifact]
+        :param splitter: Document splitter to break artifacts into smaller chunks.
+            If None, each artifact becomes a single document.
+        :type splitter: TextSplitter, optional
+        :param kwargs: Additional arguments passed to the underlying add_documents method.
+            Special handling for 'ids' kwarg:
+
+            * If provided and document is split, IDs are generated as "{original_id}_{i}"
+              where i starts from 1 (e.g., "doc1_1", "doc1_2", etc.)
+            * If provided and document isn't split, original IDs are used as-is
+
+        :return: List of IDs for all added documents. When no custom IDs are provided:
+
+            * Without splitting: Vector store generates IDs automatically
+            * With splitting: Vector store generates separate IDs for each chunk
+
+            When custom IDs are provided:
+
+            * Without splitting: Uses provided IDs directly
+            * With splitting: Generates sequential IDs as "{original_id}_{i}" for each chunk
+        :rtype: list
+        """
+        all_ids = []
+        user_ids = kwargs.pop("ids", None)
+
+        if user_ids:
+            if not isinstance(user_ids, Iterable):
+                raise ValueError("IDs must be an iterable collection")
+            if len(user_ids) != len(artifacts):
+                raise ValueError(
+                    "The number of IDs should match the number of artifacts"
+                )
+        for index, artifact in enumerate(artifacts):
+            documents = artifact.to_langchain_documents(splitter)
+            if artifact.collection_add(self.collection_name) and self._mlrun_context:
+                self._mlrun_context.update_artifact(artifact)
+            if user_ids:
+                num_of_documents = len(documents)
+                if num_of_documents > 1:
+                    ids_to_pass = [
+                        f"{user_ids[index]}_{i}" for i in range(1, num_of_documents + 1)
+                    ]
+                else:
+                    ids_to_pass = [user_ids[index]]
+                kwargs["ids"] = ids_to_pass
+            ids = self._collection_impl.add_documents(documents, **kwargs)
+            all_ids.extend(ids)
+        return all_ids
+
+    def remove_from_artifact(self, artifact: DocumentArtifact):
+        """
+        Remove the current object from the given artifact's collection and update the artifact.
+
+        Args:
+            artifact (DocumentArtifact): The artifact from which the current object should be removed.
+        """
+
+        if artifact.collection_remove(self.collection_name) and self._mlrun_context:
+            self._mlrun_context.update_artifact(artifact)
+
+    def delete_artifacts(self, artifacts: list[DocumentArtifact]):
+        """
+        Delete a list of DocumentArtifact objects from the collection.
+
+        This method removes the specified artifacts from the collection and updates the MLRun context.
+        The deletion process varies depending on the type of the underlying collection implementation.
+
+        Args:
+            artifacts (list[DocumentArtifact]): A list of DocumentArtifact objects to be deleted.
+
+        Raises:
+            NotImplementedError: If the delete operation is not supported for the collection implementation.
+        """
+        store_class = self._collection_impl.__class__.__name__.lower()
+        for artifact in artifacts:
+            if artifact.collection_remove(self.collection_name) and self._mlrun_context:
+                self._mlrun_context.update_artifact(artifact)
+
+            if store_class == "milvus":
+                expr = f"{DocumentArtifact.METADATA_SOURCE_KEY} == '{artifact.get_source()}'"
+                self._collection_impl.delete(expr=expr)
+            elif store_class == "chroma":
+                where = {DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()}
+                self._collection_impl.delete(where=where)
+            elif store_class == "pineconevectorstore":
+                filter = {
+                    DocumentArtifact.METADATA_SOURCE_KEY: {"$eq": artifact.get_source()}
+                }
+                self._collection_impl.delete(filter=filter)
+            elif store_class == "mongodbatlasvectorsearch":
+                filter = {DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()}
+                self._collection_impl.collection.delete_many(filter=filter)
+            elif (
+                hasattr(self._collection_impl, "delete")
+                and "filter"
+                in inspect.signature(self._collection_impl.delete).parameters
+            ):
+                filter = {
+                    "metadata": {
+                        DocumentArtifact.METADATA_SOURCE_KEY: artifact.get_source()
+                    }
+                }
+                self._collection_impl.delete(filter=filter)
+            else:
+                raise NotImplementedError(
+                    f"delete_artifacts() operation not supported for {store_class}"
+                )
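Putting the new module together: a LangChain vector store gets wrapped once, and from then on artifact bookkeeping (which collections a DocumentArtifact belongs to) rides along with every add and delete. A hedged end-to-end sketch, assuming langchain-chroma is installed and that the project-level log_document helper implied by the new mlrun/artifacts/document.py exists in this form:

    import mlrun
    from langchain_chroma import Chroma
    from langchain_openai import OpenAIEmbeddings
    from mlrun.datastore.vectorstore import VectorStoreCollection

    project = mlrun.get_or_create_project("rag-demo")

    # The collection name is probed from the wrapped store when not given
    # explicitly (Chroma keeps it at _collection.name, one of the patterns
    # tried by _extract_collection_name above).
    collection = VectorStoreCollection(
        mlrun_context=project,
        vector_store=Chroma(
            collection_name="docs", embedding_function=OpenAIEmbeddings()
        ),
    )

    # Assumed API: log a file as a document artifact, then index it.
    artifact = project.log_document("contract", local_path="contract.pdf")

    # With splitter=None each artifact becomes one document, so the custom
    # ID is used as-is; a splitter yielding N chunks would instead produce
    # "contract_1" ... "contract_N".
    ids = collection.add_artifacts([artifact], ids=["contract"])

    # Deletion routes through the per-backend filters shown above
    # (for Chroma: delete(where={<source key>: <artifact source>})).
    collection.delete_artifacts([artifact])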
mlrun/datastore/wasbfs/fs.py
CHANGED

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
 from urllib.parse import urlparse
 
 from fsspec import AbstractFileSystem

@@ -22,23 +23,23 @@ class WasbFS(AbstractFileSystem):
 
     def __init__(
         self,
-        account_name: str = None,
-        account_key: str = None,
-        connection_string: str = None,
-        credential: str = None,
-        sas_token: str = None,
+        account_name: Optional[str] = None,
+        account_key: Optional[str] = None,
+        connection_string: Optional[str] = None,
+        credential: Optional[str] = None,
+        sas_token: Optional[str] = None,
         request_session=None,
-        socket_timeout: int = None,
-        blocksize: int = None,
-        client_id: str = None,
-        client_secret: str = None,
-        tenant_id: str = None,
+        socket_timeout: Optional[int] = None,
+        blocksize: Optional[int] = None,
+        client_id: Optional[str] = None,
+        client_secret: Optional[str] = None,
+        tenant_id: Optional[str] = None,
         anon: bool = True,
-        location_mode: str = None,
+        location_mode: Optional[str] = None,
         loop=None,
         asynchronous: bool = False,
         default_fill_cache: bool = True,
-        default_cache_type: str = None,
+        default_cache_type: Optional[str] = None,
         **kwargs,
     ):
         from adlfs import AzureBlobFileSystem