mlrun 1.7.0rc7__py3-none-any.whl → 1.7.0rc9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic; see the package's registry page for details.

Files changed (52)
  1. mlrun/__main__.py +2 -0
  2. mlrun/common/schemas/__init__.py +3 -0
  3. mlrun/common/schemas/api_gateway.py +8 -1
  4. mlrun/common/schemas/hub.py +7 -9
  5. mlrun/common/schemas/model_monitoring/constants.py +1 -1
  6. mlrun/common/schemas/pagination.py +26 -0
  7. mlrun/common/schemas/project.py +15 -10
  8. mlrun/config.py +28 -10
  9. mlrun/datastore/__init__.py +3 -7
  10. mlrun/datastore/datastore_profile.py +19 -1
  11. mlrun/datastore/snowflake_utils.py +43 -0
  12. mlrun/datastore/sources.py +9 -26
  13. mlrun/datastore/targets.py +131 -11
  14. mlrun/datastore/utils.py +10 -5
  15. mlrun/db/base.py +44 -0
  16. mlrun/db/httpdb.py +122 -21
  17. mlrun/db/nopdb.py +107 -0
  18. mlrun/feature_store/api.py +3 -2
  19. mlrun/feature_store/retrieval/spark_merger.py +27 -23
  20. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +1 -1
  21. mlrun/frameworks/tf_keras/mlrun_interface.py +2 -2
  22. mlrun/kfpops.py +2 -5
  23. mlrun/launcher/base.py +1 -1
  24. mlrun/launcher/client.py +2 -2
  25. mlrun/model_monitoring/helpers.py +3 -1
  26. mlrun/projects/pipelines.py +1 -1
  27. mlrun/projects/project.py +32 -21
  28. mlrun/run.py +5 -1
  29. mlrun/runtimes/__init__.py +16 -0
  30. mlrun/runtimes/base.py +4 -1
  31. mlrun/runtimes/kubejob.py +26 -121
  32. mlrun/runtimes/nuclio/api_gateway.py +58 -8
  33. mlrun/runtimes/nuclio/application/application.py +79 -1
  34. mlrun/runtimes/nuclio/application/reverse_proxy.go +9 -1
  35. mlrun/runtimes/nuclio/function.py +11 -8
  36. mlrun/runtimes/nuclio/serving.py +2 -2
  37. mlrun/runtimes/pod.py +145 -0
  38. mlrun/runtimes/utils.py +0 -28
  39. mlrun/serving/remote.py +2 -3
  40. mlrun/serving/routers.py +4 -3
  41. mlrun/serving/server.py +1 -1
  42. mlrun/serving/states.py +6 -9
  43. mlrun/serving/v2_serving.py +4 -3
  44. mlrun/utils/http.py +1 -1
  45. mlrun/utils/retryer.py +1 -0
  46. mlrun/utils/version/version.json +2 -2
  47. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/METADATA +15 -15
  48. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/RECORD +52 -50
  49. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/LICENSE +0 -0
  50. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/WHEEL +0 -0
  51. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/entry_points.txt +0 -0
  52. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/top_level.txt +0 -0
mlrun/__main__.py CHANGED
@@ -505,6 +505,8 @@ def build(
     if kfp:
         print("Runtime:")
         pprint(runtime)
+        # use kind = "job" by default if not specified
+        runtime.setdefault("kind", "job")
         func = new_function(runtime=runtime)
 
     elif func_url:
mlrun/common/schemas/__init__.py CHANGED
@@ -21,6 +21,7 @@ from .api_gateway import (
     APIGatewayMetadata,
     APIGatewaysOutput,
     APIGatewaySpec,
+    APIGatewayState,
     APIGatewayStatus,
     APIGatewayUpstream,
 )
@@ -151,12 +152,14 @@ from .notification import (
     SetNotificationRequest,
 )
 from .object import ObjectKind, ObjectMetadata, ObjectSpec, ObjectStatus
+from .pagination import PaginationInfo
 from .pipeline import PipelinesFormat, PipelinesOutput, PipelinesPagination
 from .project import (
     IguazioProject,
     Project,
     ProjectDesiredState,
     ProjectMetadata,
+    ProjectOutput,
     ProjectOwner,
     ProjectsFormat,
     ProjectsOutput,
mlrun/common/schemas/api_gateway.py CHANGED
@@ -36,6 +36,13 @@ class APIGatewayAuthenticationMode(mlrun.common.types.StrEnum):
     )
 
 
+class APIGatewayState(mlrun.common.types.StrEnum):
+    none = ""
+    ready = "ready"
+    error = "error"
+    waiting_for_provisioning = "waitingForProvisioning"
+
+
 class _APIGatewayBaseModel(pydantic.BaseModel):
     class Config:
         extra = pydantic.Extra.allow
@@ -72,7 +79,7 @@ class APIGatewaySpec(_APIGatewayBaseModel):
 
 class APIGatewayStatus(_APIGatewayBaseModel):
     name: Optional[str]
-    state: Optional[str]
+    state: Optional[APIGatewayState]
 
 
 class APIGateway(_APIGatewayBaseModel):
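For illustration, a minimal sketch of consuming the now-typed state on the client side; the status values below are made up, not taken from the package:

import mlrun.common.schemas as schemas

# Hypothetical status payload; pydantic coerces the raw string into the new enum.
status = schemas.APIGatewayStatus(name="my-gateway", state="waitingForProvisioning")
if status.state == schemas.APIGatewayState.ready:
    print("API gateway is ready")
else:
    print(f"API gateway not ready yet, state={status.state}")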
mlrun/common/schemas/hub.py CHANGED
@@ -59,28 +59,26 @@ class HubSource(BaseModel):
         return f"{self.spec.path}/{self.spec.object_type}/{self.spec.channel}/{relative_path}"
 
     def get_catalog_uri(self):
-        return self.get_full_uri(mlrun.config.config.hub.catalog_filename)
+        return self.get_full_uri(mlrun.mlconf.hub.catalog_filename)
 
     @classmethod
     def generate_default_source(cls):
-        if not mlrun.config.config.hub.default_source.create:
+        if not mlrun.mlconf.hub.default_source.create:
             return None
 
         now = datetime.now(timezone.utc)
         hub_metadata = HubObjectMetadata(
-            name=mlrun.config.config.hub.default_source.name,
-            description=mlrun.config.config.hub.default_source.description,
+            name=mlrun.mlconf.hub.default_source.name,
+            description=mlrun.mlconf.hub.default_source.description,
             created=now,
             updated=now,
         )
         return cls(
             metadata=hub_metadata,
             spec=HubSourceSpec(
-                path=mlrun.config.config.hub.default_source.url,
-                channel=mlrun.config.config.hub.default_source.channel,
-                object_type=HubSourceType(
-                    mlrun.config.config.hub.default_source.object_type
-                ),
+                path=mlrun.mlconf.hub.default_source.url,
+                channel=mlrun.mlconf.hub.default_source.channel,
+                object_type=HubSourceType(mlrun.mlconf.hub.default_source.object_type),
             ),
             status=ObjectStatus(state="created"),
         )
mlrun/common/schemas/model_monitoring/constants.py CHANGED
@@ -151,7 +151,7 @@ class ProjectSecretKeys:
     ENDPOINT_STORE_CONNECTION = "MODEL_MONITORING_ENDPOINT_STORE_CONNECTION"
     ACCESS_KEY = "MODEL_MONITORING_ACCESS_KEY"
     PIPELINES_ACCESS_KEY = "MODEL_MONITORING_PIPELINES_ACCESS_KEY"
-    KAFKA_BOOTSTRAP_SERVERS = "KAFKA_BOOTSTRAP_SERVERS"
+    KAFKA_BROKERS = "KAFKA_BROKERS"
     STREAM_PATH = "STREAM_PATH"
 
 
mlrun/common/schemas/pagination.py CHANGED
@@ -0,0 +1,26 @@
+# Copyright 2023 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+
+import pydantic
+
+
+class PaginationInfo(pydantic.BaseModel):
+    class Config:
+        allow_population_by_field_name = True
+
+    page: typing.Optional[int]
+    page_size: typing.Optional[int] = pydantic.Field(alias="page-size")
+    page_token: typing.Optional[str] = pydantic.Field(alias="page-token")
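A short sketch of how the alias configuration is expected to behave, assuming pydantic v1 semantics (values are made up):

from mlrun.common.schemas import PaginationInfo

# The aliases let the model parse query-param style keys...
info = PaginationInfo(**{"page": 2, "page-size": 50, "page-token": "abc123"})

# ...while allow_population_by_field_name also accepts the snake_case field names.
same_info = PaginationInfo(page=2, page_size=50, page_token="abc123")

assert info == same_info
print(info.dict(by_alias=True))  # {'page': 2, 'page-size': 50, 'page-token': 'abc123'}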
mlrun/common/schemas/project.py CHANGED
@@ -120,17 +120,22 @@ class IguazioProject(pydantic.BaseModel):
     data: dict
 
 
+# The format query param controls the project type used:
+# full - Project
+# name_only - str
+# summary - ProjectSummary
+# leader - currently only IguazioProject supported
+# The way pydantic handles typing.Union is that it takes the object and tries to coerce it to be the types of the
+# union by the definition order. Therefore we can't currently add generic dict for all leader formats, but we need
+# to add a specific classes for them. it's frustrating but couldn't find other workaround, see:
+# https://github.com/samuelcolvin/pydantic/issues/1423, https://github.com/samuelcolvin/pydantic/issues/619
+ProjectOutput = typing.TypeVar(
+    "ProjectOutput", Project, str, ProjectSummary, IguazioProject
+)
+
+
 class ProjectsOutput(pydantic.BaseModel):
-    # The format query param controls the project type used:
-    # full - Project
-    # name_only - str
-    # summary - ProjectSummary
-    # leader - currently only IguazioProject supported
-    # The way pydantic handles typing.Union is that it takes the object and tries to coerce it to be the types of the
-    # union by the definition order. Therefore we can't currently add generic dict for all leader formats, but we need
-    # to add a specific classes for them. it's frustrating but couldn't find other workaround, see:
-    # https://github.com/samuelcolvin/pydantic/issues/1423, https://github.com/samuelcolvin/pydantic/issues/619
-    projects: list[typing.Union[Project, str, ProjectSummary, IguazioProject]]
+    projects: list[ProjectOutput]
 
 
 class ProjectSummariesOutput(pydantic.BaseModel):
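The moved comment refers to pydantic v1's union coercion order; a small standalone illustration of that behavior, unrelated to mlrun's own models:

import typing

import pydantic


class Example(pydantic.BaseModel):
    # With a plain Union, pydantic v1 tries the member types in declaration order,
    # so "3" is coerced to the int 3 instead of staying a string.
    value: typing.Union[int, str]


print(Example(value="3").value)    # 3 (int), coerced by the first matching type
print(Example(value="abc").value)  # "abc" falls through to str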
mlrun/config.py CHANGED
@@ -240,6 +240,7 @@ default_config = {
         "remote": "mlrun/mlrun",
         "dask": "mlrun/ml-base",
         "mpijob": "mlrun/mlrun",
+        "application": "python:3.9-slim",
     },
     # see enrich_function_preemption_spec for more info,
     # and mlrun.common.schemas.function.PreemptionModes for available options
@@ -481,10 +482,13 @@ default_config = {
             # if set to true, will log a warning for trying to use run db functionality while in nop db mode
             "verbose": True,
         },
-        "pagination_cache": {
-            "interval": 60,
-            "ttl": 3600,
-            "max_size": 10000,
+        "pagination": {
+            "default_page_size": 20,
+            "pagination_cache": {
+                "interval": 60,
+                "ttl": 3600,
+                "max_size": 10000,
+            },
         },
     },
     "model_endpoint_monitoring": {
@@ -548,6 +552,7 @@ default_config = {
             "nosql": "v3io:///projects/{project}/FeatureStore/{name}/{kind}",
             # "authority" is optional and generalizes [userinfo "@"] host [":" port]
             "redisnosql": "redis://{authority}/projects/{project}/FeatureStore/{name}/{kind}",
+            "dsnosql": "ds://{ds_profile_name}/projects/{project}/FeatureStore/{name}/{kind}",
         },
         "default_targets": "parquet,nosql",
         "default_job_image": "mlrun/mlrun",
@@ -1073,7 +1078,7 @@ class Config:
         target: str = "online",
         artifact_path: str = None,
         function_name: str = None,
-    ) -> str:
+    ) -> typing.Union[str, list[str]]:
         """Get the full path from the configuration based on the provided project and kind.
 
         :param project: Project name.
@@ -1089,7 +1094,8 @@
                               relative artifact path will be taken from the global MLRun artifact path.
         :param function_name: Application name, None for model_monitoring_stream.
 
-        :return: Full configured path for the provided kind.
+        :return: Full configured path for the provided kind. Can be either a single path
+                 or a list of paths in the case of the online model monitoring stream path.
         """
 
         if target != "offline":
@@ -1111,10 +1117,22 @@
                     if function_name is None
                     else f"{kind}-{function_name.lower()}",
                 )
-            return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format(
-                project=project,
-                kind=kind,
-            )
+            elif kind == "stream":  # return list for mlrun<1.6.3 BC
+                return [
+                    mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format(
+                        project=project,
+                        kind=kind,
+                    ),  # old stream uri (pipelines) for BC ML-6043
+                    mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format(
+                        project=project,
+                        kind=kind,
+                    ),  # new stream uri (projects)
+                ]
+            else:
+                return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format(
+                    project=project,
+                    kind=kind,
+                )
 
         # Get the current offline path from the configuration
         file_path = mlrun.mlconf.model_endpoint_monitoring.offline_storage_path.format(
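Callers of this helper now need to handle both return shapes; a defensive sketch (the method name below is assumed from context, it is not shown in this hunk):

import mlrun

# Assumed name; the hunk above only shows the signature tail and docstring.
stream_paths = mlrun.mlconf.get_model_monitoring_file_target_path(
    project="my-project", kind="stream"
)
# Normalize so the same code handles the old (str) and new (list) shapes.
if not isinstance(stream_paths, list):
    stream_paths = [stream_paths]
for path in stream_paths:
    print(path)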
mlrun/datastore/__init__.py CHANGED
@@ -107,13 +107,9 @@ def get_stream_pusher(stream_path: str, **kwargs):
     :param stream_path: path/url of stream
     """
 
-    if stream_path.startswith("kafka://") or "kafka_bootstrap_servers" in kwargs:
-        topic, bootstrap_servers = parse_kafka_url(
-            stream_path, kwargs.get("kafka_bootstrap_servers")
-        )
-        return KafkaOutputStream(
-            topic, bootstrap_servers, kwargs.get("kafka_producer_options")
-        )
+    if stream_path.startswith("kafka://") or "kafka_brokers" in kwargs:
+        topic, brokers = parse_kafka_url(stream_path, kwargs.get("kafka_brokers"))
+        return KafkaOutputStream(topic, brokers, kwargs.get("kafka_producer_options"))
    elif stream_path.startswith("http://") or stream_path.startswith("https://"):
         return HTTPOutputStream(stream_path=stream_path)
     elif "://" not in stream_path:
mlrun/datastore/datastore_profile.py CHANGED
@@ -16,6 +16,7 @@ import ast
 import base64
 import json
 import typing
+import warnings
 from urllib.parse import ParseResult, urlparse, urlunparse
 
 import pydantic
@@ -68,6 +69,9 @@ class TemporaryClientDatastoreProfiles(metaclass=mlrun.utils.singleton.Singleton
     def get(self, key):
         return self._data.get(key, None)
 
+    def remove(self, key):
+        self._data.pop(key, None)
+
 
 class DatastoreProfileBasic(DatastoreProfile):
     type: str = pydantic.Field("basic")
@@ -80,12 +84,22 @@ class DatastoreProfileKafkaTarget(DatastoreProfile):
     type: str = pydantic.Field("kafka_target")
     _private_attributes = "kwargs_private"
     bootstrap_servers: str
+    brokers: str
     topic: str
     kwargs_public: typing.Optional[dict]
     kwargs_private: typing.Optional[dict]
 
+    def __pydantic_post_init__(self):
+        if self.bootstrap_servers:
+            warnings.warn(
+                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
+                "use 'brokers' instead.",
+                # TODO: Remove this in 1.9.0
+                FutureWarning,
+            )
+
     def attributes(self):
-        attributes = {"bootstrap_servers": self.bootstrap_servers}
+        attributes = {"brokers": self.brokers or self.bootstrap_servers}
         if self.kwargs_public:
             attributes = merge(attributes, self.kwargs_public)
         if self.kwargs_private:
@@ -460,3 +474,7 @@ def register_temporary_client_datastore_profile(profile: DatastoreProfile):
     It's beneficial for testing purposes.
     """
     TemporaryClientDatastoreProfiles().add(profile)
+
+
+def remove_temporary_client_datastore_profile(profile_name: str):
+    TemporaryClientDatastoreProfiles().remove(profile_name)
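A small sketch pairing the new remove helper with the existing register call; profile field values here are assumptions for illustration:

from mlrun.datastore.datastore_profile import (
    DatastoreProfileBasic,
    register_temporary_client_datastore_profile,
    remove_temporary_client_datastore_profile,
)

# Register an in-memory (client-only) profile, e.g. in a test, then clean it up.
profile = DatastoreProfileBasic(name="my-profile", public="some-config", private="some-secret")
register_temporary_client_datastore_profile(profile)
try:
    ...  # code that resolves ds://my-profile/... paths
finally:
    remove_temporary_client_datastore_profile("my-profile")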
mlrun/datastore/snowflake_utils.py CHANGED
@@ -0,0 +1,43 @@
+# Copyright 2024 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mlrun
+
+
+def get_snowflake_password():
+    key = "SNOWFLAKE_PASSWORD"
+    snowflake_password = mlrun.get_secret_or_env(key)
+
+    if not snowflake_password:
+        raise mlrun.errors.MLRunInvalidArgumentError(
+            f"No password provided. Set password using the {key} "
+            "project secret or environment variable."
+        )
+
+    return snowflake_password
+
+
+def get_snowflake_spark_options(attributes):
+    return {
+        "format": "net.snowflake.spark.snowflake",
+        "sfURL": attributes.get("url"),
+        "sfUser": attributes.get("user"),
+        "sfPassword": get_snowflake_password(),
+        "sfDatabase": attributes.get("database"),
+        "sfSchema": attributes.get("schema"),
+        "sfWarehouse": attributes.get("warehouse"),
+        "application": "iguazio_platform",
+        "TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_LTZ",
+    }
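Both helpers are shared by the Snowflake source and the new Snowflake target below; a usage sketch with placeholder connection values:

import os

from mlrun.datastore.snowflake_utils import get_snowflake_spark_options

# The password is resolved through mlrun.get_secret_or_env, so either a project
# secret or a plain environment variable named SNOWFLAKE_PASSWORD works.
os.environ["SNOWFLAKE_PASSWORD"] = "<password>"

spark_options = get_snowflake_spark_options(
    {
        "url": "<account>.<region>.snowflakecomputing.com",
        "user": "MY_USER",
        "database": "MY_DB",
        "schema": "PUBLIC",
        "warehouse": "MY_WH",
    }
)
# The "format" key is popped out and handed to Spark separately, as the changes in
# sources.py and targets.py below do.
spark_format = spark_options.pop("format")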
mlrun/datastore/sources.py CHANGED
@@ -28,6 +28,7 @@ from nuclio.config import split_path
 
 import mlrun
 from mlrun.config import config
+from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.secrets import SecretsStore
 
 from ..model import DataSource
@@ -113,7 +114,11 @@ class BaseSourceDriver(DataSource):
 
     def to_spark_df(self, session, named_view=False, time_field=None, columns=None):
         if self.support_spark:
-            df = load_spark_dataframe_with_options(session, self.get_spark_options())
+            spark_options = self.get_spark_options()
+            spark_format = spark_options.pop("format", None)
+            df = load_spark_dataframe_with_options(
+                session, spark_options, format=spark_format
+            )
             if named_view:
                 df.createOrReplaceTempView(self.name)
             return self._filter_spark_df(df, time_field, columns)
@@ -673,32 +678,10 @@ class SnowflakeSource(BaseSourceDriver):
             **kwargs,
         )
 
-    def _get_password(self):
-        key = "SNOWFLAKE_PASSWORD"
-        snowflake_password = os.getenv(key) or os.getenv(
-            SecretsStore.k8s_env_variable_name_for_secret(key)
-        )
-
-        if not snowflake_password:
-            raise mlrun.errors.MLRunInvalidArgumentError(
-                "No password provided. Set password using the SNOWFLAKE_PASSWORD "
-                "project secret or environment variable."
-            )
-
-        return snowflake_password
-
     def get_spark_options(self):
-        return {
-            "format": "net.snowflake.spark.snowflake",
-            "query": self.attributes.get("query"),
-            "sfURL": self.attributes.get("url"),
-            "sfUser": self.attributes.get("user"),
-            "sfPassword": self._get_password(),
-            "sfDatabase": self.attributes.get("database"),
-            "sfSchema": self.attributes.get("schema"),
-            "sfWarehouse": self.attributes.get("warehouse"),
-            "application": "iguazio_platform",
-        }
+        spark_options = get_snowflake_spark_options(self.attributes)
+        spark_options["query"] = self.attributes.get("query")
+        return spark_options
 
 
 class CustomSource(BaseSourceDriver):
mlrun/datastore/targets.py CHANGED
@@ -17,6 +17,7 @@ import os
 import random
 import sys
 import time
+import warnings
 from collections import Counter
 from copy import copy
 from typing import Any, Optional, Union
@@ -28,6 +29,7 @@ from mergedeep import merge
 import mlrun
 import mlrun.utils.helpers
 from mlrun.config import config
+from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.model import DataSource, DataTarget, DataTargetBase, TargetPathObject
 from mlrun.utils import logger, now_date
 from mlrun.utils.helpers import to_parquet
@@ -57,6 +59,7 @@ class TargetTypes:
     dataframe = "dataframe"
     custom = "custom"
     sql = "sql"
+    snowflake = "snowflake"
 
     @staticmethod
     def all():
@@ -71,6 +74,7 @@
             TargetTypes.dataframe,
             TargetTypes.custom,
             TargetTypes.sql,
+            TargetTypes.snowflake,
         ]
 
 
@@ -78,11 +82,14 @@ def generate_target_run_id():
     return f"{round(time.time() * 1000)}_{random.randint(0, 999)}"
 
 
-def write_spark_dataframe_with_options(spark_options, df, mode):
+def write_spark_dataframe_with_options(spark_options, df, mode, write_format=None):
     non_hadoop_spark_options = spark_session_update_hadoop_options(
         df.sql_ctx.sparkSession, spark_options
     )
-    df.write.mode(mode).save(**non_hadoop_spark_options)
+    if write_format:
+        df.write.format(write_format).mode(mode).save(**non_hadoop_spark_options)
+    else:
+        df.write.mode(mode).save(**non_hadoop_spark_options)
 
 
 def default_target_names():
@@ -497,7 +504,10 @@ class BaseStoreTarget(DataTargetBase):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df, key_column, timestamp_key, options)
-            write_spark_dataframe_with_options(options, df, "overwrite")
+            write_format = options.pop("format", None)
+            write_spark_dataframe_with_options(
+                options, df, "overwrite", write_format=write_format
+            )
         elif hasattr(df, "dask"):
             dask_options = self.get_dask_options()
             store, path_in_store, target_path = self._get_store_and_path()
@@ -524,7 +534,12 @@
         store, path_in_store, target_path = self._get_store_and_path()
         target_path = generate_path_with_chunk(self, chunk_id, target_path)
         file_system = store.filesystem
-        if file_system.protocol == "file":
+        if (
+            file_system.protocol == "file"
+            # fsspec 2023.10.0 changed protocol from "file" to ("file", "local")
+            or isinstance(file_system.protocol, (tuple, list))
+            and "file" in file_system.protocol
+        ):
             dir = os.path.dirname(target_path)
             if dir:
                 os.makedirs(dir, exist_ok=True)
@@ -1108,6 +1123,97 @@ class CSVTarget(BaseStoreTarget):
         return True
 
 
+class SnowflakeTarget(BaseStoreTarget):
+    """
+    :param attributes: A dictionary of attributes for Snowflake connection; will be overridden by database parameters
+                       if they exist.
+    :param url: Snowflake hostname, in the format: <account_name>.<region>.snowflakecomputing.com
+    :param user: Snowflake user for login
+    :param db_schema: Database schema
+    :param database: Database name
+    :param warehouse: Snowflake warehouse name
+    :param table_name: Snowflake table name
+    """
+
+    support_spark = True
+    support_append = True
+    is_offline = True
+    kind = TargetTypes.snowflake
+
+    def __init__(
+        self,
+        name: str = "",
+        path=None,
+        attributes: dict[str, str] = None,
+        after_step=None,
+        columns=None,
+        partitioned: bool = False,
+        key_bucketing_number: Optional[int] = None,
+        partition_cols: Optional[list[str]] = None,
+        time_partitioning_granularity: Optional[str] = None,
+        max_events: Optional[int] = None,
+        flush_after_seconds: Optional[int] = None,
+        storage_options: dict[str, str] = None,
+        schema: dict[str, Any] = None,
+        credentials_prefix=None,
+        url: str = None,
+        user: str = None,
+        db_schema: str = None,
+        database: str = None,
+        warehouse: str = None,
+        table_name: str = None,
+    ):
+        attrs = {
+            "url": url,
+            "user": user,
+            "database": database,
+            "schema": db_schema,
+            "warehouse": warehouse,
+            "table": table_name,
+        }
+        extended_attrs = {
+            key: value for key, value in attrs.items() if value is not None
+        }
+        attributes = {} if not attributes else attributes
+        attributes.update(extended_attrs)
+        super().__init__(
+            name,
+            path,
+            attributes,
+            after_step,
+            list(schema.keys()) if schema else columns,
+            partitioned,
+            key_bucketing_number,
+            partition_cols,
+            time_partitioning_granularity,
+            max_events=max_events,
+            flush_after_seconds=flush_after_seconds,
+            storage_options=storage_options,
+            schema=schema,
+            credentials_prefix=credentials_prefix,
+        )
+
+    def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True):
+        spark_options = get_snowflake_spark_options(self.attributes)
+        spark_options["dbtable"] = self.attributes.get("table")
+        return spark_options
+
+    def purge(self):
+        pass
+
+    def as_df(
+        self,
+        columns=None,
+        df_module=None,
+        entities=None,
+        start_time=None,
+        end_time=None,
+        time_column=None,
+        **kwargs,
+    ):
+        raise NotImplementedError()
+
+
 class NoSqlBaseTarget(BaseStoreTarget):
     is_table = True
     is_online = True
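A construction sketch for the new target, following the constructor shown above; connection values are placeholders and SNOWFLAKE_PASSWORD must be available as a project secret or environment variable:

from mlrun.datastore.targets import SnowflakeTarget

target = SnowflakeTarget(
    name="snowflake-target",
    url="<account>.<region>.snowflakecomputing.com",
    user="MY_USER",
    database="MY_DB",
    db_schema="PUBLIC",
    warehouse="MY_WH",
    table_name="MY_FEATURE_TABLE",
)
# Spark-only, offline target: get_spark_options() maps these attributes onto the
# Snowflake Spark connector options, and as_df() raises NotImplementedError.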
@@ -1179,7 +1285,10 @@ class NoSqlBaseTarget(BaseStoreTarget):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df)
-            write_spark_dataframe_with_options(options, df, "overwrite")
+            write_format = options.pop("format", None)
+            write_spark_dataframe_with_options(
+                options, df, "overwrite", write_format=write_format
+            )
         else:
             # To prevent modification of the original dataframe and make sure
             # that the last event of a key is the one being persisted
1419
1528
  *args,
1420
1529
  bootstrap_servers=None,
1421
1530
  producer_options=None,
1531
+ brokers=None,
1422
1532
  **kwargs,
1423
1533
  ):
1424
1534
  attrs = {}
1535
+ if bootstrap_servers:
1536
+ warnings.warn(
1537
+ "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
1538
+ "use 'brokers' instead.",
1539
+ # TODO: Remove this in 1.9.0
1540
+ FutureWarning,
1541
+ )
1425
1542
  if bootstrap_servers is not None:
1426
- attrs["bootstrap_servers"] = bootstrap_servers
1543
+ attrs["brokers"] = brokers or bootstrap_servers
1427
1544
  if producer_options is not None:
1428
1545
  attrs["producer_options"] = producer_options
1429
1546
 
@@ -1445,14 +1562,16 @@ class KafkaTarget(BaseStoreTarget):
1445
1562
  if self.path and self.path.startswith("ds://"):
1446
1563
  datastore_profile = datastore_profile_read(self.path)
1447
1564
  attributes = datastore_profile.attributes()
1448
- bootstrap_servers = attributes.pop("bootstrap_servers", None)
1565
+ brokers = attributes.pop(
1566
+ "brokers", attributes.pop("bootstrap_servers", None)
1567
+ )
1449
1568
  topic = datastore_profile.topic
1450
1569
  else:
1451
1570
  attributes = copy(self.attributes)
1452
- bootstrap_servers = attributes.pop("bootstrap_servers", None)
1453
- topic, bootstrap_servers = parse_kafka_url(
1454
- self.get_target_path(), bootstrap_servers
1571
+ brokers = attributes.pop(
1572
+ "brokers", attributes.pop("bootstrap_servers", None)
1455
1573
  )
1574
+ topic, brokers = parse_kafka_url(self.get_target_path(), brokers)
1456
1575
 
1457
1576
  if not topic:
1458
1577
  raise mlrun.errors.MLRunInvalidArgumentError(
@@ -1466,7 +1585,7 @@ class KafkaTarget(BaseStoreTarget):
1466
1585
  class_name="storey.KafkaTarget",
1467
1586
  columns=column_list,
1468
1587
  topic=topic,
1469
- bootstrap_servers=bootstrap_servers,
1588
+ brokers=brokers,
1470
1589
  **attributes,
1471
1590
  )
1472
1591
 
@@ -1957,6 +2076,7 @@ kind_to_driver = {
     TargetTypes.tsdb: TSDBTarget,
     TargetTypes.custom: CustomTarget,
     TargetTypes.sql: SQLTarget,
+    TargetTypes.snowflake: SnowflakeTarget,
 }
 
 