mlrun 1.7.0rc3__py3-none-any.whl → 1.7.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (76)
  1. mlrun/artifacts/manager.py +6 -1
  2. mlrun/common/constants.py +2 -0
  3. mlrun/common/model_monitoring/helpers.py +12 -6
  4. mlrun/common/schemas/__init__.py +11 -0
  5. mlrun/common/schemas/api_gateway.py +85 -0
  6. mlrun/common/schemas/auth.py +2 -2
  7. mlrun/common/schemas/client_spec.py +1 -0
  8. mlrun/common/schemas/common.py +40 -0
  9. mlrun/common/schemas/model_monitoring/constants.py +4 -1
  10. mlrun/common/schemas/project.py +2 -0
  11. mlrun/config.py +31 -17
  12. mlrun/datastore/azure_blob.py +22 -9
  13. mlrun/datastore/base.py +15 -25
  14. mlrun/datastore/datastore.py +19 -8
  15. mlrun/datastore/datastore_profile.py +47 -5
  16. mlrun/datastore/google_cloud_storage.py +10 -6
  17. mlrun/datastore/hdfs.py +51 -0
  18. mlrun/datastore/redis.py +4 -0
  19. mlrun/datastore/s3.py +4 -0
  20. mlrun/datastore/sources.py +29 -43
  21. mlrun/datastore/targets.py +59 -53
  22. mlrun/datastore/utils.py +2 -49
  23. mlrun/datastore/v3io.py +4 -0
  24. mlrun/db/base.py +50 -0
  25. mlrun/db/httpdb.py +121 -50
  26. mlrun/db/nopdb.py +13 -0
  27. mlrun/execution.py +3 -3
  28. mlrun/feature_store/feature_vector.py +2 -2
  29. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +3 -3
  30. mlrun/frameworks/tf_keras/model_handler.py +7 -7
  31. mlrun/k8s_utils.py +10 -5
  32. mlrun/kfpops.py +19 -10
  33. mlrun/model.py +5 -0
  34. mlrun/model_monitoring/api.py +3 -3
  35. mlrun/model_monitoring/application.py +1 -1
  36. mlrun/model_monitoring/applications/__init__.py +13 -0
  37. mlrun/model_monitoring/applications/histogram_data_drift.py +218 -0
  38. mlrun/model_monitoring/batch.py +9 -111
  39. mlrun/model_monitoring/controller.py +73 -55
  40. mlrun/model_monitoring/controller_handler.py +13 -5
  41. mlrun/model_monitoring/features_drift_table.py +62 -53
  42. mlrun/model_monitoring/helpers.py +30 -21
  43. mlrun/model_monitoring/metrics/__init__.py +13 -0
  44. mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
  45. mlrun/model_monitoring/stores/kv_model_endpoint_store.py +14 -14
  46. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -1
  47. mlrun/package/packagers/pandas_packagers.py +3 -3
  48. mlrun/package/utils/_archiver.py +3 -1
  49. mlrun/platforms/iguazio.py +8 -65
  50. mlrun/projects/pipelines.py +21 -11
  51. mlrun/projects/project.py +180 -42
  52. mlrun/run.py +1 -1
  53. mlrun/runtimes/base.py +25 -2
  54. mlrun/runtimes/kubejob.py +5 -3
  55. mlrun/runtimes/local.py +2 -2
  56. mlrun/runtimes/mpijob/abstract.py +6 -6
  57. mlrun/runtimes/nuclio/__init__.py +1 -0
  58. mlrun/runtimes/nuclio/api_gateway.py +300 -0
  59. mlrun/runtimes/nuclio/function.py +9 -9
  60. mlrun/runtimes/nuclio/serving.py +3 -3
  61. mlrun/runtimes/pod.py +3 -3
  62. mlrun/runtimes/sparkjob/spark3job.py +3 -3
  63. mlrun/serving/remote.py +4 -2
  64. mlrun/serving/server.py +2 -8
  65. mlrun/utils/async_http.py +3 -3
  66. mlrun/utils/helpers.py +27 -5
  67. mlrun/utils/http.py +3 -3
  68. mlrun/utils/logger.py +2 -2
  69. mlrun/utils/notifications/notification_pusher.py +6 -6
  70. mlrun/utils/version/version.json +2 -2
  71. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc5.dist-info}/METADATA +13 -16
  72. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc5.dist-info}/RECORD +76 -68
  73. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc5.dist-info}/WHEEL +1 -1
  74. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc5.dist-info}/LICENSE +0 -0
  75. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc5.dist-info}/entry_points.txt +0 -0
  76. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc5.dist-info}/top_level.txt +0 -0
mlrun/artifacts/manager.py CHANGED
@@ -17,7 +17,11 @@ from os.path import exists, isdir
 from urllib.parse import urlparse
 
 import mlrun.config
-from mlrun.utils.helpers import get_local_file_schema, template_artifact_path
+from mlrun.utils.helpers import (
+    get_local_file_schema,
+    template_artifact_path,
+    validate_inline_artifact_body_size,
+)
 
 from ..utils import (
     is_legacy_artifact,
@@ -212,6 +216,7 @@ class ArtifactManager:
         target_path = target_path or item.target_path
 
         validate_artifact_key_name(key, "artifact.key")
+        validate_inline_artifact_body_size(item.spec.inline)
         src_path = local_path or item.src_path  # TODO: remove src_path
         self.ensure_artifact_source_file_exists(item=item, path=src_path, body=body)
         if format == "html" or (src_path and pathlib.Path(src_path).suffix == "html"):
mlrun/common/constants.py CHANGED
@@ -13,3 +13,5 @@
 # limitations under the License.
 #
 IMAGE_NAME_ENRICH_REGISTRY_PREFIX = "."  # prefix for image name to enrich with registry
+MLRUN_CREATED_LABEL = "mlrun-created"
+MYSQL_MEDIUMBLOB_SIZE_BYTES = 16 * 1024 * 1024
mlrun/common/model_monitoring/helpers.py CHANGED
@@ -16,6 +16,7 @@ import sys
 import typing
 
 import mlrun.common
+import mlrun.common.schemas.model_monitoring.constants as mm_constants
 from mlrun.common.schemas.model_monitoring import (
     EndpointUID,
     FunctionURI,
@@ -64,7 +65,7 @@ def parse_model_endpoint_store_prefix(store_prefix: str):
 
 
 def parse_monitoring_stream_path(
-    stream_uri: str, project: str, application_name: str = None
+    stream_uri: str, project: str, function_name: str = None
 ):
     if stream_uri.startswith("kafka://"):
         if "?topic" in stream_uri:
@@ -72,23 +73,28 @@ def parse_monitoring_stream_path(
                 "Custom kafka topic is not allowed"
             )
         # Add topic to stream kafka uri
-        if application_name is None:
+        if (
+            function_name is None
+            or function_name == mm_constants.MonitoringFunctionNames.STREAM
+        ):
             stream_uri += f"?topic=monitoring_stream_{project}"
         else:
-            stream_uri += f"?topic=monitoring_stream_{project}_{application_name}"
+            stream_uri += f"?topic=monitoring_stream_{project}_{function_name}"
 
     elif stream_uri.startswith("v3io://") and mlrun.mlconf.is_ce_mode():
         # V3IO is not supported in CE mode, generating a default http stream path
-        if application_name is None:
+        if function_name is None:
             stream_uri = (
                 mlrun.mlconf.model_endpoint_monitoring.default_http_sink.format(
-                    project=project
+                    project=project, namespace=mlrun.mlconf.namespace
                 )
             )
         else:
            stream_uri = (
                 mlrun.mlconf.model_endpoint_monitoring.default_http_sink_app.format(
-                    project=project, application_name=application_name
+                    project=project,
+                    application_name=function_name,
+                    namespace=mlrun.mlconf.namespace,
                 )
             )
     return stream_uri
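
For illustration only (not part of the diff), a minimal sketch of how the renamed function_name argument affects the Kafka topic produced by parse_monitoring_stream_path; the broker address, project, and application names are hypothetical:

    from mlrun.common.model_monitoring.helpers import parse_monitoring_stream_path

    # No function name (or the dedicated stream function) -> project-level topic
    parse_monitoring_stream_path("kafka://broker:9092", "my-project")
    # -> "kafka://broker:9092?topic=monitoring_stream_my-project"

    # A monitoring application function name is appended to the topic
    parse_monitoring_stream_path("kafka://broker:9092", "my-project", "my-app")
    # -> "kafka://broker:9092?topic=monitoring_stream_my-project_my-app"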
mlrun/common/schemas/__init__.py CHANGED
@@ -14,6 +14,16 @@
 #
 # flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
 
+from .api_gateway import (
+    APIGateway,
+    APIGatewayAuthenticationMode,
+    APIGatewayBasicAuth,
+    APIGatewayMetadata,
+    APIGatewaysOutput,
+    APIGatewaySpec,
+    APIGatewayStatus,
+    APIGatewayUpstream,
+)
 from .artifact import (
     Artifact,
     ArtifactCategories,
@@ -43,6 +53,7 @@ from .clusterization_spec import (
     ClusterizationSpec,
     WaitForChiefToReachOnlineStateFeatureFlag,
 )
+from .common import ImageBuilder
 from .constants import (
     APIStates,
     ClusterizationRole,
mlrun/common/schemas/api_gateway.py ADDED
@@ -0,0 +1,85 @@
+# Copyright 2023 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import typing
+from typing import Optional
+
+import pydantic
+
+import mlrun.common.types
+
+
+class APIGatewayAuthenticationMode(mlrun.common.types.StrEnum):
+    basic = "basicAuth"
+    none = "none"
+
+    @classmethod
+    def from_str(cls, authentication_mode: str):
+        if authentication_mode == "none":
+            return cls.none
+        elif authentication_mode == "basicAuth":
+            return cls.basic
+        else:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                f"Authentication mode `{authentication_mode}` is not supported",
+            )
+
+
+class _APIGatewayBaseModel(pydantic.BaseModel):
+    class Config:
+        extra = pydantic.Extra.allow
+
+
+class APIGatewayMetadata(_APIGatewayBaseModel):
+    name: str
+    namespace: Optional[str]
+    labels: Optional[dict] = {}
+
+
+class APIGatewayBasicAuth(_APIGatewayBaseModel):
+    username: str
+    password: str
+
+
+class APIGatewayUpstream(_APIGatewayBaseModel):
+    kind: Optional[str] = "nucliofunction"
+    nucliofunction: dict[str, str]
+    percentage: Optional[int] = 0
+
+
+class APIGatewaySpec(_APIGatewayBaseModel):
+    name: str
+    description: Optional[str]
+    path: Optional[str] = "/"
+    authenticationMode: Optional[APIGatewayAuthenticationMode] = (
+        APIGatewayAuthenticationMode.none
+    )
+    upstreams: list[APIGatewayUpstream]
+    authentication: Optional[dict[str, Optional[APIGatewayBasicAuth]]]
+    host: Optional[str]
+
+
+class APIGatewayStatus(_APIGatewayBaseModel):
+    name: Optional[str]
+    state: Optional[str]
+
+
+class APIGateway(_APIGatewayBaseModel):
+    metadata: APIGatewayMetadata
+    spec: APIGatewaySpec
+    status: Optional[APIGatewayStatus]
+
+
+class APIGatewaysOutput(_APIGatewayBaseModel):
+    api_gateways: typing.Optional[dict[str, APIGateway]] = {}
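
As a rough, illustrative sketch (not part of the diff), the new API gateway schemas are ordinary pydantic models and can be constructed directly; the gateway and function names below are hypothetical:

    from mlrun.common.schemas import (
        APIGateway,
        APIGatewayMetadata,
        APIGatewaySpec,
        APIGatewayUpstream,
    )

    # An API gateway routing all traffic to a single nuclio function
    gateway = APIGateway(
        metadata=APIGatewayMetadata(name="my-gateway"),
        spec=APIGatewaySpec(
            name="my-gateway",
            path="/",
            upstreams=[
                APIGatewayUpstream(
                    nucliofunction={"name": "my-function"}, percentage=100
                )
            ],
        ),
    )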
mlrun/common/schemas/auth.py CHANGED
@@ -59,7 +59,7 @@ class AuthorizationResourceTypes(mlrun.common.types.StrEnum):
     hub_source = "hub-source"
     workflow = "workflow"
     datastore_profile = "datastore-profile"
-    api_gateways = "api-gateways"
+    api_gateway = "api-gateway"
 
     def to_resource_string(
         self,
@@ -94,7 +94,7 @@ class AuthorizationResourceTypes(mlrun.common.types.StrEnum):
             AuthorizationResourceTypes.hub_source: "/marketplace/sources",
             # workflow define how to run a pipeline and can be considered as the specification of a pipeline.
             AuthorizationResourceTypes.workflow: "/projects/{project_name}/workflows/{resource_name}",
-            AuthorizationResourceTypes.api_gateways: "/projects/{project_name}/api-gateways",
+            AuthorizationResourceTypes.api_gateway: "/projects/{project_name}/api-gateways/{resource_name}",
         }[self].format(project_name=project_name, resource_name=resource_name)
 
 
mlrun/common/schemas/client_spec.py CHANGED
@@ -29,6 +29,7 @@ class ClientSpec(pydantic.BaseModel):
     ui_url: typing.Optional[str]
     artifact_path: typing.Optional[str]
     feature_store_data_prefixes: typing.Optional[dict[str, str]]
+    feature_store_default_targets: typing.Optional[str]
     spark_app_image: typing.Optional[str]
     spark_app_image_tag: typing.Optional[str]
     spark_history_server_path: typing.Optional[str]
mlrun/common/schemas/common.py ADDED
@@ -0,0 +1,40 @@
+# Copyright 2023 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import typing
+
+import pydantic
+
+
+class ImageBuilder(pydantic.BaseModel):
+    functionSourceCode: typing.Optional[str] = None
+    codeEntryType: typing.Optional[str] = None
+    codeEntryAttributes: typing.Optional[str] = None
+    source: typing.Optional[str] = None
+    code_origin: typing.Optional[str] = None
+    origin_filename: typing.Optional[str] = None
+    image: typing.Optional[str] = None
+    base_image: typing.Optional[str] = None
+    commands: typing.Optional[list] = None
+    extra: typing.Optional[str] = None
+    extra_args: typing.Optional[dict] = None
+    builder_env: typing.Optional[dict] = None
+    secret: typing.Optional[str] = None
+    registry: typing.Optional[str] = None
+    load_source_on_run: typing.Optional[bool] = None
+    with_mlrun: typing.Optional[bool] = None
+    auto_build: typing.Optional[bool] = None
+    build_pod: typing.Optional[str] = None
+    requirements: typing.Optional[list] = None
+    source_code_target_dir: typing.Optional[str] = None
mlrun/common/schemas/model_monitoring/constants.py CHANGED
@@ -181,7 +181,7 @@ class MonitoringFunctionNames:
     WRITER = "model-monitoring-writer"
     BATCH = "model-monitoring-batch"
     APPLICATION_CONTROLLER = "model-monitoring-controller"
-    STREAM = None
+    STREAM = "model-monitoring-stream"
 
     @staticmethod
     def all():
@@ -289,3 +289,6 @@ class ModelMonitoringAppLabel:
 
 class ControllerPolicy:
     BASE_PERIOD = "base_period"
+
+
+MLRUN_HISTOGRAM_DATA_DRIFT_APP_NAME = "histogram-data-drift"
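
A short illustrative note (not part of the diff): STREAM is now a concrete function name rather than None, so callers can compare against it directly, e.g.:

    from mlrun.common.schemas.model_monitoring.constants import MonitoringFunctionNames

    MonitoringFunctionNames.STREAM  # "model-monitoring-stream"
    # parse_monitoring_stream_path above treats this name like the default stream function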
mlrun/common/schemas/project.py CHANGED
@@ -19,6 +19,7 @@ import pydantic
 
 import mlrun.common.types
 
+from .common import ImageBuilder
 from .object import ObjectKind, ObjectStatus
 
 
@@ -85,6 +86,7 @@ class ProjectSpec(pydantic.BaseModel):
     desired_state: typing.Optional[ProjectDesiredState] = ProjectDesiredState.online
     custom_packagers: typing.Optional[list[tuple[str, bool]]] = None
     default_image: typing.Optional[str] = None
+    build: typing.Optional[ImageBuilder] = None
 
     class Config:
         extra = pydantic.Extra.allow
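
A minimal sketch (not part of the diff) of the new project-level build configuration; every ImageBuilder field is optional, and the values below are hypothetical:

    from mlrun.common.schemas import ImageBuilder

    build = ImageBuilder(
        base_image="mlrun/mlrun",
        commands=["pip install scikit-learn"],
        requirements=["pandas"],
        with_mlrun=True,
    )
    # The new ProjectSpec.build field (typing.Optional[ImageBuilder]) accepts such an object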
mlrun/config.py CHANGED
@@ -287,6 +287,12 @@ default_config = {
         "state": "online",
         "retry_api_call_on_exception": "enabled",
         "http_connection_timeout_keep_alive": 11,
+        # http client used by httpdb
+        "http": {
+            # when True, the client will verify the server's TLS
+            # set to False for backwards compatibility.
+            "verify": False,
+        },
         "db": {
             "commit_retry_timeout": 30,
             "commit_retry_interval": 3,
@@ -484,8 +490,8 @@ default_config = {
         "offline_storage_path": "model-endpoints/{kind}",
         # Default http path that points to the monitoring stream nuclio function. Will be used as a stream path
         # when the user is working in CE environment and has not provided any stream path.
-        "default_http_sink": "http://nuclio-{project}-model-monitoring-stream.mlrun.svc.cluster.local:8080",
-        "default_http_sink_app": "http://nuclio-{project}-{application_name}.mlrun.svc.cluster.local:8080",
+        "default_http_sink": "http://nuclio-{project}-model-monitoring-stream.{namespace}.svc.cluster.local:8080",
+        "default_http_sink_app": "http://nuclio-{project}-{application_name}.{namespace}.svc.cluster.local:8080",
         "batch_processing_function_branch": "master",
         "parquet_batching_max_events": 10_000,
         "parquet_batching_timeout_secs": timedelta(minutes=1).total_seconds(),
@@ -605,7 +611,7 @@ default_config = {
     "workflows": {
         "default_workflow_runner_name": "workflow-runner-{}",
         # Default timeout seconds for retrieving workflow id after execution:
-        "timeouts": {"local": 120, "kfp": 30, "remote": 30},
+        "timeouts": {"local": 120, "kfp": 30, "remote": 90},
     },
     "log_collector": {
         "address": "localhost:8282",
@@ -957,10 +963,10 @@ class Config:
             with_gpu = (
                 with_gpu_requests if requirement == "requests" else with_gpu_limits
             )
-            resources[
-                requirement
-            ] = self.get_default_function_pod_requirement_resources(
-                requirement, with_gpu
+            resources[requirement] = (
+                self.get_default_function_pod_requirement_resources(
+                    requirement, with_gpu
+                )
             )
         return resources
 
@@ -1053,7 +1059,7 @@
         kind: str = "",
         target: str = "online",
         artifact_path: str = None,
-        application_name: str = None,
+        function_name: str = None,
     ) -> str:
         """Get the full path from the configuration based on the provided project and kind.
 
@@ -1068,7 +1074,7 @@
                                  artifact path instead.
         :param artifact_path:    Optional artifact path that will be used as a relative path. If not provided, the
                                  relative artifact path will be taken from the global MLRun artifact path.
-        :param application_name: Application name, None for model_monitoring_stream.
+        :param function_name:    Application name, None for model_monitoring_stream.
 
         :return: Full configured path for the provided kind.
         """
@@ -1082,20 +1088,19 @@
             return store_prefix_dict[kind].format(project=project)
 
         if (
-            application_name
+            function_name
+            and function_name
            != mlrun.common.schemas.model_monitoring.constants.MonitoringFunctionNames.STREAM
         ):
             return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format(
                 project=project,
                 kind=kind
-                if application_name is None
-                else f"{kind}-{application_name.lower()}",
+                if function_name is None
+                else f"{kind}-{function_name.lower()}",
             )
         return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format(
             project=project,
-            kind=kind
-            if application_name is None
-            else f"{kind}-{application_name.lower()}",
+            kind=kind,
         )
 
     # Get the current offline path from the configuration
@@ -1343,12 +1348,21 @@ def read_env(env=None, prefix=env_prefix):
     if igz_domain:
         config["ui_url"] = f"https://mlrun-ui.{igz_domain}"
 
-    if config.get("log_level"):
+    if log_level := config.get("log_level"):
         import mlrun.utils.logger
 
         # logger created (because of imports mess) before the config is loaded (in tests), therefore we're changing its
         # level manually
-        mlrun.utils.logger.set_logger_level(config["log_level"])
+        mlrun.utils.logger.set_logger_level(log_level)
+
+    if log_formatter_name := config.get("log_formatter"):
+        import mlrun.utils.logger
+
+        log_formatter = mlrun.utils.create_formatter_instance(
+            mlrun.utils.FormatterKinds(log_formatter_name)
+        )
+        mlrun.utils.logger.get_handler("default").setFormatter(log_formatter)
+
     # The default function pod resource values are of type str; however, when reading from environment variable numbers,
     # it converts them to type int if contains only number, so we want to convert them to str.
     _convert_resources_to_str(config)
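
A hedged sketch (not part of the diff) of the new client TLS-verification flag; it assumes the http block sits under the httpdb section, as the surrounding keys suggest, and that the usual MLRUN_ double-underscore environment-variable nesting applies:

    import mlrun

    # Defaults to False for backwards compatibility
    print(mlrun.mlconf.httpdb.http.verify)

    # Can be flipped in code, or via an env var such as MLRUN_HTTPDB__HTTP__VERIFY=true
    mlrun.mlconf.httpdb.http.verify = True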
mlrun/datastore/azure_blob.py CHANGED
@@ -175,9 +175,9 @@ class AzureBlobStore(DataStore):
 
         if "client_secret" in st or "client_id" in st or "tenant_id" in st:
             res[f"spark.hadoop.fs.azure.account.auth.type.{host}"] = "OAuth"
-            res[
-                f"spark.hadoop.fs.azure.account.oauth.provider.type.{host}"
-            ] = "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider"
+            res[f"spark.hadoop.fs.azure.account.oauth.provider.type.{host}"] = (
+                "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider"
+            )
             if "client_id" in st:
                 res[f"spark.hadoop.fs.azure.account.oauth2.client.id.{host}"] = st[
                     "client_id"
@@ -188,14 +188,27 @@ class AzureBlobStore(DataStore):
                 ]
             if "tenant_id" in st:
                 tenant_id = st["tenant_id"]
-                res[
-                    f"spark.hadoop.fs.azure.account.oauth2.client.endpoint.{host}"
-                ] = f"https://login.microsoftonline.com/{tenant_id}/oauth2/token"
+                res[f"spark.hadoop.fs.azure.account.oauth2.client.endpoint.{host}"] = (
+                    f"https://login.microsoftonline.com/{tenant_id}/oauth2/token"
+                )
 
         if "sas_token" in st:
             res[f"spark.hadoop.fs.azure.account.auth.type.{host}"] = "SAS"
-            res[
-                f"spark.hadoop.fs.azure.sas.token.provider.type.{host}"
-            ] = "org.apache.hadoop.fs.azurebfs.sas.FixedSASTokenProvider"
+            res[f"spark.hadoop.fs.azure.sas.token.provider.type.{host}"] = (
+                "org.apache.hadoop.fs.azurebfs.sas.FixedSASTokenProvider"
+            )
             res[f"spark.hadoop.fs.azure.sas.fixed.token.{host}"] = st["sas_token"]
         return res
+
+    @property
+    def spark_url(self):
+        spark_options = self.get_spark_options()
+        url = f"wasbs://{self.endpoint}"
+        prefix = "spark.hadoop.fs.azure.account.key."
+        if spark_options:
+            for key in spark_options:
+                if key.startswith(prefix):
+                    account_key = key[len(prefix) :]
+                    url += f"@{account_key}"
+                    break
+        return url
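
For illustration (not part of the diff), the new spark_url property derives the wasbs URL from the account-key spark option; the account and container names here are hypothetical:

    # Mirrors the loop in AzureBlobStore.spark_url shown above
    spark_options = {
        "spark.hadoop.fs.azure.account.key.myaccount.blob.core.windows.net": "<key>"
    }
    prefix = "spark.hadoop.fs.azure.account.key."
    url = "wasbs://mycontainer"  # f"wasbs://{self.endpoint}"
    for key in spark_options:
        if key.startswith(prefix):
            url += f"@{key[len(prefix):]}"
            break
    print(url)  # wasbs://mycontainer@myaccount.blob.core.windows.net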
mlrun/datastore/base.py CHANGED
@@ -147,6 +147,10 @@ class DataStore:
     def url(self):
         return f"{self.kind}://{self.endpoint}"
 
+    @property
+    def spark_url(self):
+        return self.url
+
     def get(self, key, size=None, offset=0):
         pass
 
@@ -320,31 +324,17 @@
             raise Exception(f"File type unhandled {url}")
 
         if file_system:
-            if (
-                self.supports_isdir()
-                and file_system.isdir(file_url)
-                or self._is_dd(df_module)
-            ):
-                storage_options = self.get_storage_options()
-                if url.startswith("ds://"):
-                    parsed_url = urllib.parse.urlparse(url)
-                    url = parsed_url.path
-                    if self.using_bucket:
-                        url = url[1:]
-                    # Pass the underlying file system
-                    kwargs["filesystem"] = file_system
-                elif storage_options:
-                    kwargs["storage_options"] = storage_options
-                df = reader(url, **kwargs)
-            else:
-                file = url
-                # Workaround for ARROW-12472 affecting pyarrow 3.x and 4.x.
-                if file_system.protocol != "file":
-                    # If not dir, use file_system.open() to avoid regression when pandas < 1.2 and does not
-                    # support the storage_options parameter.
-                    file = file_system.open(url)
-
-                df = reader(file, **kwargs)
+            storage_options = self.get_storage_options()
+            if url.startswith("ds://"):
+                parsed_url = urllib.parse.urlparse(url)
+                url = parsed_url.path
+                if self.using_bucket:
+                    url = url[1:]
+                # Pass the underlying file system
+                kwargs["filesystem"] = file_system
+            elif storage_options:
+                kwargs["storage_options"] = storage_options
+            df = reader(url, **kwargs)
         else:
             temp_file = tempfile.NamedTemporaryFile(delete=False)
             self.download(self._join(subpath), temp_file.name)
mlrun/datastore/datastore.py CHANGED
@@ -94,6 +94,10 @@ def schema_to_store(schema):
         from .dbfs_store import DBFSStore
 
         return DBFSStore
+    elif schema == "hdfs":
+        from .hdfs import HdfsStore
+
+        return HdfsStore
     else:
         raise ValueError(f"unsupported store scheme ({schema})")
 
@@ -170,7 +174,7 @@ class StoreManager:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 f"resource {url} does not have a valid/persistent offline target"
             )
-        return resource, target
+        return resource, target or ""
 
     def object(
         self, url, key="", project="", allow_empty_resources=None, secrets: dict = None
@@ -182,14 +186,21 @@
             url, project, allow_empty_resources, secrets
         )
 
-        store, subpath = self.get_or_create_store(
+        store, subpath, url = self.get_or_create_store(
             url, secrets=secrets, project_name=project
         )
-        return DataItem(key, store, subpath, url, meta=meta, artifact_url=artifact_url)
+        return DataItem(
+            key,
+            store,
+            subpath,
+            url,
+            meta=meta,
+            artifact_url=artifact_url,
+        )
 
     def get_or_create_store(
         self, url, secrets: dict = None, project_name=""
-    ) -> (DataStore, str):
+    ) -> (DataStore, str, str):
         schema, endpoint, parsed_url = parse_url(url)
         subpath = parsed_url.path
         store_key = f"{schema}://{endpoint}"
@@ -206,17 +217,17 @@
 
         if schema == "memory":
             subpath = url[len("memory://") :]
-            return in_memory_store, subpath
+            return in_memory_store, subpath, url
 
         if not schema and endpoint:
             if endpoint in self._stores.keys():
-                return self._stores[endpoint], subpath
+                return self._stores[endpoint], subpath, url
             else:
                 raise ValueError(f"no such store ({endpoint})")
 
         if not secrets and not mlrun.config.is_running_as_api():
             if store_key in self._stores.keys():
-                return self._stores[store_key], subpath
+                return self._stores[store_key], subpath, url
 
         # support u/p embedding in url (as done in redis) by setting netloc as the "endpoint" parameter
         # when running on server we don't cache the datastore, because there are multiple users and we don't want to
@@ -227,7 +238,7 @@
         if not secrets and not mlrun.config.is_running_as_api():
             self._stores[store_key] = store
         # in file stores in windows path like c:\a\b the drive letter is dropped from the path, so we return the url
-        return store, url if store.kind == "file" else subpath
+        return store, url if store.kind == "file" else subpath, url

    def reset_secrets(self):
        self._secrets = {}
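
A small sketch (not part of the diff) of the new three-element return value, assuming the module-level store_manager singleton in mlrun.datastore; the memory:// URL avoids any credential setup:

    import mlrun.datastore

    # get_or_create_store now returns (store, subpath, url) instead of (store, subpath)
    store, subpath, url = mlrun.datastore.store_manager.get_or_create_store(
        "memory://results/df.csv"
    )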