genesis-flow 1.0.1__py3-none-any.whl → 1.0.3__py3-none-any.whl

@@ -0,0 +1,1286 @@
"""
The ``mlflow.spark`` module provides an API for logging and loading Spark MLlib models. This module
exports Spark MLlib models with the following flavors:

Spark MLlib (native) format
    Allows models to be loaded as Spark Transformers for scoring in a Spark session.
    Models with this flavor can be loaded as PySpark PipelineModel objects in Python.
    This is the main flavor and is always produced.
:py:mod:`mlflow.pyfunc`
    Supports deployment outside of Spark by instantiating a SparkContext and reading
    input data as a Spark DataFrame prior to scoring. Also supports deployment in Spark
    as a Spark UDF. Models with this flavor can be loaded as Python functions
    for performing inference. This flavor is always produced.
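
A minimal usage sketch (illustrative; assumes ``model`` is a fitted
``pyspark.ml`` ``PipelineModel`` and tracking is configured):

.. code-block:: python

    import mlflow

    with mlflow.start_run():
        model_info = mlflow.spark.log_model(model, "spark-model")

    # Load back as a generic Python function for inference
    loaded = mlflow.pyfunc.load_model(model_info.model_uri)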
"""

import logging
import os
import posixpath
import re
import shutil
from typing import Any, Optional
from urllib.parse import urlparse

import yaml
from packaging.version import Version

import mlflow
from mlflow import environment_variables, pyfunc
from mlflow.environment_variables import MLFLOW_DFS_TMP
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelInputExample, ModelSignature
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.signature import _LOG_MODEL_INFER_SIGNATURE_WARNING_TEMPLATE
from mlflow.models.utils import _Example, _save_example
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.store.artifact.databricks_artifact_repo import DatabricksArtifactRepository
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import (
    _download_artifact_from_uri,
    _get_root_uri_and_artifact_path,
)
from mlflow.types.schema import SparkMLVector
from mlflow.utils import _get_fully_qualified_class_name, databricks_utils
from mlflow.utils.autologging_utils import autologging_integration, safe_patch
from mlflow.utils.class_utils import _get_class_from_string
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.utils.environment import (
    _CONDA_ENV_FILE_NAME,
    _CONSTRAINTS_FILE_NAME,
    _PYTHON_ENV_FILE_NAME,
    _REQUIREMENTS_FILE_NAME,
    _mlflow_conda_env,
    _process_conda_env,
    _process_pip_requirements,
    _PythonEnv,
    _validate_env_arguments,
)
from mlflow.utils.file_utils import (
    TempDir,
    get_total_file_size,
    shutil_copytree_without_file_permissions,
    write_to,
)
from mlflow.utils.model_utils import (
    _add_code_from_conf_to_system_path,
    _validate_and_copy_code_paths,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement
from mlflow.utils.uri import (
    append_to_uri_path,
    dbfs_hdfs_uri_to_fuse_path,
    generate_tmp_dfs_path,
    get_databricks_profile_uri_from_artifact_uri,
    is_databricks_acled_artifacts_uri,
    is_local_uri,
    is_valid_dbfs_uri,
)

FLAVOR_NAME = "spark"

_SPARK_MODEL_PATH_SUB = "sparkml"
_MLFLOWDBFS_SCHEME = "mlflowdbfs"


_logger = logging.getLogger(__name__)


def get_default_pip_requirements(is_spark_connect_model=False):
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by this flavor.
        Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
        that, at minimum, contains these requirements.
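
    A minimal illustration (the pinned version below is a placeholder; actual values
    depend on the installed PySpark):

    .. code-block:: python

        import mlflow.spark

        reqs = mlflow.spark.get_default_pip_requirements()
        # e.g. ["pyspark==3.5.1"]; "pandas<2" is also included on PySpark < 3.4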
    """
    import pyspark

    # Strip the suffix from `dev` versions of PySpark, which are not
    # available for installation from Anaconda or PyPI
    pyspark_req_str = "pyspark[connect]" if is_spark_connect_model else "pyspark"
    pyspark_req = re.sub(r"(\.?)dev.*$", "", _get_pinned_requirement(pyspark_req_str))
    reqs = [pyspark_req]
    if Version(pyspark.__version__) < Version("3.4"):
        # Versions of PySpark < 3.4 are incompatible with pandas >= 2
        reqs.append("pandas<2")

    if is_spark_connect_model:
        reqs.extend(
            [
                # Spark connect ML Model uses spark torch distributor to train model
                _get_pinned_requirement("torch"),
                # Spark connect ML Model saves feature transformers as sklearn transformer format.
                _get_pinned_requirement("scikit-learn", module="sklearn"),
                # Spark connect ML evaluators depend on torcheval package.
                _get_pinned_requirement("torcheval"),
            ]
        )
    return reqs


def get_default_conda_env(is_spark_connect_model=False):
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model()` and :func:`log_model()`. This Conda environment
        contains the current version of PySpark that is installed on the caller's
        system. ``dev`` versions of PySpark are replaced with stable versions in
        the resulting Conda environment (e.g., if you are running PySpark version
        ``2.4.5.dev0``, invoking this method produces a Conda environment with a
        dependency on PySpark version ``2.4.5``).
    """
    return _mlflow_conda_env(
        additional_pip_deps=get_default_pip_requirements(
            is_spark_connect_model=is_spark_connect_model
        )
    )


@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="pyspark"))
def log_model(
    spark_model,
    artifact_path,
    conda_env=None,
    code_paths=None,
    dfs_tmpdir=None,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
):
    """
    Log a Spark MLlib model as an MLflow artifact for the current run. This uses the
    MLlib persistence format and produces an MLflow Model with the Spark flavor.

    Note: If no run is active, it will instantiate a run to obtain a run_id.

    Args:
        spark_model: Spark model to be saved - MLflow can only save descendants of
            pyspark.ml.Model or pyspark.ml.Transformer which implement
            MLReadable and MLWritable.

            .. Note:: The provided Spark model's `transform` method must generate a column
                named "prediction", which is used as the MLflow pyfunc model output.
                By default, most Spark models generate an output column named "prediction"
                that contains the prediction labels.
                To use the probability column as the output column for probabilistic
                classification models, set the "probabilityCol" param to "prediction"
                and the "predictionCol" param to ""
                (e.g. `model.setProbabilityCol("prediction").setPredictionCol("")`).
        artifact_path: Run relative artifact path.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
            filesystem if running in local mode. The model is written to this
            destination and then copied into the model's artifact directory. This is
            necessary as Spark ML models read from and write to DFS if running on a
            cluster. If this operation completes successfully, all temporary files
            created on the DFS are removed. Defaults to ``/tmp/mlflow``.
            For models defined in the `pyspark.ml.connect` module, this param is ignored.
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        signature: A Model Signature object that describes the input and output Schema of the
            model. The model signature can be inferred using the `infer_signature` function
            of `mlflow.models.signature`.
            Note that if your Spark model contains a Spark ML vector type input or output
            column, you should declare that column using the ``SparkMLVector`` type;
            the `infer_signature` function can also infer the ``SparkMLVector`` type
            correctly from Spark DataFrame input / output.
            When a Spark ML model with a ``SparkMLVector`` type input is loaded as an MLflow
            pyfunc model, it accepts ``Array[double]`` type input. MLflow internally converts
            the array into a Spark ML vector and then invokes the Spark model for inference.
            Similarly, if the model has a vector type output, MLflow internally converts the
            Spark ML vector output into an ``Array[double]`` type inference result.

            .. code-block:: python

                from mlflow.models import infer_signature
                from pyspark.sql.functions import col
                from pyspark.ml.classification import LogisticRegression
                from pyspark.ml.functions import array_to_vector
                import pandas as pd
                import mlflow

                train_df = spark.createDataFrame(
                    [([3.0, 4.0], 0), ([5.0, 6.0], 1)], schema="features array<double>, label long"
                ).select(array_to_vector("features").alias("features"), col("label"))
                lor = LogisticRegression(maxIter=2)
                lor.setPredictionCol("").setProbabilityCol("prediction")
                lor_model = lor.fit(train_df)

                test_df = train_df.select("features")
                prediction_df = lor_model.transform(train_df).select("prediction")

                signature = infer_signature(test_df, prediction_df)

                with mlflow.start_run() as run:
                    model_info = mlflow.spark.log_model(
                        lor_model,
                        "model",
                        signature=signature,
                    )

                # The following signature is output:
                # inputs:
                #   ['features': SparkML vector (required)]
                # outputs:
                #   ['prediction': SparkML vector (required)]
                print(model_info.signature)

                loaded = mlflow.pyfunc.load_model(model_info.model_uri)

                test_dataset = pd.DataFrame({"features": [[1.0, 2.0]]})

                # `loaded.predict` accepts an `Array[double]` type input column,
                # and generates an `Array[double]` type output column.
                print(loaded.predict(test_dataset))

        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and reach the ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.

    .. code-block:: python
        :caption: Example

        from pyspark.ml import Pipeline
        from pyspark.ml.classification import LogisticRegression
        from pyspark.ml.feature import HashingTF, Tokenizer

        training = spark.createDataFrame(
            [
                (0, "a b c d e spark", 1.0),
                (1, "b d", 0.0),
                (2, "spark f g h", 1.0),
                (3, "hadoop mapreduce", 0.0),
            ],
            ["id", "text", "label"],
        )
        tokenizer = Tokenizer(inputCol="text", outputCol="words")
        hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
        lr = LogisticRegression(maxIter=10, regParam=0.001)
        pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
        model = pipeline.fit(training)
        mlflow.spark.log_model(model, "spark-model")
    """
    _validate_model(spark_model)
    from pyspark.ml import PipelineModel

    if _is_spark_connect_model(spark_model):
        # TODO: Use `Model.log` once `mlflowdbfs` supports logged model artifacts.
        # `mlflowdbfs` doesn't support logged model artifacts yet, so we use `Model._log_v2`.
        return Model._log_v2(
            artifact_path=artifact_path,
            flavor=mlflow.spark,
            spark_model=spark_model,
            conda_env=conda_env,
            code_paths=code_paths,
            registered_model_name=registered_model_name,
            signature=signature,
            input_example=input_example,
            await_registration_for=await_registration_for,
            pip_requirements=pip_requirements,
            extra_pip_requirements=extra_pip_requirements,
            metadata=metadata,
        )

    if not isinstance(spark_model, PipelineModel):
        spark_model = PipelineModel([spark_model])
    run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
    run_root_artifact_uri = mlflow.get_artifact_uri()
    remote_model_path = None
    if _should_use_mlflowdbfs(run_root_artifact_uri):
        remote_model_path = append_to_uri_path(
            run_root_artifact_uri, artifact_path, _SPARK_MODEL_PATH_SUB
        )
        mlflowdbfs_path = _mlflowdbfs_path(run_id, artifact_path)
        with databricks_utils.MlflowCredentialContext(
            get_databricks_profile_uri_from_artifact_uri(run_root_artifact_uri)
        ):
            try:
                spark_model.save(mlflowdbfs_path)
            except Exception as e:
                raise MlflowException("failed to save spark model via mlflowdbfs") from e

    # If the artifact URI is a local filesystem path, defer to Model.log() to persist the model,
    # since Spark may not be able to write directly to the driver's filesystem. For example,
    # writing to `file:/uri` will write to the local filesystem from each executor, which will
    # be incorrect on multi-node clusters.
    # If the artifact URI is not a local filesystem path we attempt to write directly to the
    # artifact repo via Spark. If this fails, we defer to Model.log().
    elif (
        is_local_uri(run_root_artifact_uri)
        or databricks_utils.is_in_databricks_serverless_runtime()
        or databricks_utils.is_in_databricks_shared_cluster_runtime()
        or not _maybe_save_model(
            spark_model,
            append_to_uri_path(run_root_artifact_uri, artifact_path),
        )
    ):
        dfs_tmpdir = dfs_tmpdir or MLFLOW_DFS_TMP.get()
        _check_databricks_uc_volume_tmpdir_availability(dfs_tmpdir)
        # TODO: Use `Model.log` once `mlflowdbfs` supports logged model artifacts.
        # `mlflowdbfs` doesn't support logged model artifacts yet, so we use `Model._log_v2`.
        return Model._log_v2(
            artifact_path=artifact_path,
            flavor=mlflow.spark,
            spark_model=spark_model,
            conda_env=conda_env,
            code_paths=code_paths,
            dfs_tmpdir=dfs_tmpdir,
            registered_model_name=registered_model_name,
            signature=signature,
            input_example=input_example,
            await_registration_for=await_registration_for,
            pip_requirements=pip_requirements,
            extra_pip_requirements=extra_pip_requirements,
            metadata=metadata,
        )
    # Otherwise, override the default model log behavior and save model directly to artifact repo
    mlflow_model = Model(artifact_path=artifact_path, run_id=run_id)
    with TempDir() as tmp:
        tmp_model_metadata_dir = tmp.path()
        _save_model_metadata(
            tmp_model_metadata_dir,
            spark_model,
            mlflow_model,
            conda_env,
            code_paths,
            signature=signature,
            input_example=input_example,
            pip_requirements=pip_requirements,
            extra_pip_requirements=extra_pip_requirements,
            remote_model_path=remote_model_path,
        )
        mlflow.tracking.fluent.log_artifacts(tmp_model_metadata_dir, artifact_path)
        mlflow.tracking.fluent._record_logged_model(mlflow_model)
        if registered_model_name is not None:
            mlflow.register_model(
                f"runs:/{run_id}/{artifact_path}",
                registered_model_name,
                await_registration_for,
            )
        return mlflow_model.get_model_info()


def _mlflowdbfs_path(run_id, artifact_path):
    if artifact_path.startswith("/"):
        raise MlflowException(
            f"artifact_path should be relative, found: {artifact_path}",
            INVALID_PARAMETER_VALUE,
        )
    return "{}:///artifacts?run_id={}&path=/{}".format(
        _MLFLOWDBFS_SCHEME, run_id, posixpath.join(artifact_path, _SPARK_MODEL_PATH_SUB)
    )
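

# Illustrative shape of the URI built by _mlflowdbfs_path (run id and artifact
# path below are placeholders):
#   _mlflowdbfs_path("abc123", "model")
#   -> "mlflowdbfs:///artifacts?run_id=abc123&path=/model/sparkml"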


def _maybe_save_model(spark_model, model_dir):
    from py4j.protocol import Py4JError

    try:
        spark_model.save(posixpath.join(model_dir, _SPARK_MODEL_PATH_SUB))
        return True
    except Py4JError:
        return False


class _HadoopFileSystem:
    """
    Interface to org.apache.hadoop.fs.FileSystem.

    Spark ML models expect to read from and write to Hadoop FileSystem when running on a cluster.
    Since MLflow works on local directories, we need this interface to copy the files between
    the current DFS and local dir.
    """

    def __init__(self):
        raise Exception("This class should not be instantiated")

    _filesystem = None
    _conf = None

    @classmethod
    def _jvm(cls):
        from pyspark import SparkContext

        return SparkContext._gateway.jvm

    @classmethod
    def _fs(cls):
        if not cls._filesystem:
            cls._filesystem = cls._jvm().org.apache.hadoop.fs.FileSystem.get(cls._conf())
        return cls._filesystem

    @classmethod
    def _conf(cls):
        from pyspark import SparkContext

        sc = SparkContext.getOrCreate()
        return sc._jsc.hadoopConfiguration()

    @classmethod
    def _local_path(cls, path):
        return cls._jvm().org.apache.hadoop.fs.Path(os.path.abspath(path))

    @classmethod
    def _remote_path(cls, path):
        return cls._jvm().org.apache.hadoop.fs.Path(path)

    @classmethod
    def _stats(cls):
        return cls._jvm().org.apache.hadoop.fs.FileSystem.getGlobalStorageStatistics()

    @classmethod
    def copy_to_local_file(cls, src, dst, remove_src):
        cls._fs().copyToLocalFile(remove_src, cls._remote_path(src), cls._local_path(dst))

    @classmethod
    def copy_from_local_file(cls, src, dst, remove_src):
        cls._fs().copyFromLocalFile(remove_src, cls._local_path(src), cls._remote_path(dst))

    @classmethod
    def qualified_local_path(cls, path):
        return cls._fs().makeQualified(cls._local_path(path)).toString()

    @classmethod
    def maybe_copy_from_local_file(cls, src, dst):
        """
        Conditionally copy the file to the Hadoop DFS.
        The file is copied iff the configuration points to a distributed filesystem.

        Returns:
            If copied, the new target location; otherwise, the (absolute) source path.
        """
        local_path = cls._local_path(src)
        qualified_local_path = cls._fs().makeQualified(local_path).toString()
        if qualified_local_path == "file:" + local_path.toString():
            return local_path.toString()
        cls.copy_from_local_file(src, dst, remove_src=False)
        _logger.info("Copied SparkML model to %s", dst)
        return dst

    @classmethod
    def _try_file_exists(cls, dfs_path):
        try:
            return cls._fs().exists(dfs_path)
        except Exception as ex:
            # Log a debug-level message, since existence checks may raise exceptions
            # in normal operating circumstances that do not warrant warnings
            _logger.debug(
                "Unexpected exception while checking if model uri is visible on DFS: %s", ex
            )
            return False

    @classmethod
    def maybe_copy_from_uri(cls, src_uri, dst_path, local_model_path=None):
        """
        Conditionally copy the file to the Hadoop DFS from the source uri.
        In case the file is already on the Hadoop DFS do nothing.

        Returns:
            If copied, the new target location; otherwise, the source uri.
        """
        try:
            # makeQualified throws if wrong schema / uri
            dfs_path = cls._fs().makeQualified(cls._remote_path(src_uri))
            if cls._try_file_exists(dfs_path):
                _logger.info("File '%s' is already on DFS, copy is not necessary.", src_uri)
                return src_uri
        except Exception:
            _logger.info("URI '%s' does not point to the current DFS.", src_uri)
        _logger.info("File '%s' not found on DFS. Will attempt to upload the file.", src_uri)
        return cls.maybe_copy_from_local_file(
            local_model_path or _download_artifact_from_uri(src_uri), dst_path
        )

    @classmethod
    def delete(cls, path):
        cls._fs().delete(cls._remote_path(path), True)

    @classmethod
    def is_filesystem_available(cls, scheme):
        return scheme in [stats.getScheme() for stats in cls._stats().iterator()]


def _should_use_mlflowdbfs(root_uri):
    # The `mlflowdbfs` scheme does not appear in the available schemes returned from
    # the Hadoop FileSystem API until a read call has been issued.
    from mlflow.utils._spark_utils import _get_active_spark_session

    if (
        databricks_utils.is_in_databricks_serverless_runtime()
        or databricks_utils.is_in_databricks_shared_cluster_runtime()
        or not is_valid_dbfs_uri(root_uri)
        or not is_databricks_acled_artifacts_uri(root_uri)
        or not databricks_utils.is_in_databricks_runtime()
        or (environment_variables._DISABLE_MLFLOWDBFS.get() or "").lower() == "true"
    ):
        return False

    try:
        databricks_utils._get_dbutils()
    except Exception:
        # If dbutils is unavailable, indicate that mlflowdbfs is unavailable
        # because usage of mlflowdbfs depends on dbutils
        return False

    mlflowdbfs_read_exception_str = None
    try:
        _get_active_spark_session().read.load("mlflowdbfs:///artifact?run_id=foo&path=/bar")
    except Exception as e:
        # The load invocation is expected to throw an exception.
        mlflowdbfs_read_exception_str = str(e)

    try:
        return _HadoopFileSystem.is_filesystem_available(_MLFLOWDBFS_SCHEME)
    except Exception:
        # The HDFS filesystem logic used to determine mlflowdbfs availability on Databricks
        # clusters may not work on certain Databricks cluster types due to unavailability of
        # the _HadoopFileSystem.is_filesystem_available() API. As a temporary workaround,
        # we check the contents of the expected exception raised by a dummy mlflowdbfs
        # read for evidence that mlflowdbfs is available. If "MlflowdbfsClient" is present
        # in the exception contents, we can safely assume that mlflowdbfs is available because
        # `MlflowdbfsClient` is exclusively used by mlflowdbfs for performing MLflow
        # file storage operations
        #
        # TODO: Remove this logic once the _HadoopFileSystem.is_filesystem_available() check
        # above is determined to work on all Databricks cluster types
        return "MlflowdbfsClient" in (mlflowdbfs_read_exception_str or "")


def _save_model_metadata(
    dst_dir,
    spark_model,
    mlflow_model,
    conda_env,
    code_paths,
    signature=None,
    input_example=None,
    pip_requirements=None,
    extra_pip_requirements=None,
    remote_model_path=None,
):
    """
    Saves model metadata into the passed-in directory.
    If mlflowdbfs is not used, the persisted metadata assumes that the model can be
    loaded from a path relative to the metadata file (currently hard-coded to "sparkml").
    If mlflowdbfs is used, remote_model_path should be provided, and the model needs to
    be loaded from the remote_model_path.
    """
    import pyspark

    is_spark_connect_model = _is_spark_connect_model(spark_model)
    if signature is not None:
        mlflow_model.signature = signature
    if input_example is not None:
        _save_example(mlflow_model, input_example, dst_dir)

    code_dir_subpath = _validate_and_copy_code_paths(code_paths, dst_dir)
    mlflow_model.add_flavor(
        FLAVOR_NAME,
        pyspark_version=pyspark.__version__,
        model_data=_SPARK_MODEL_PATH_SUB,
        code=code_dir_subpath,
        model_class=_get_fully_qualified_class_name(spark_model),
    )
    pyfunc.add_to_model(
        mlflow_model,
        loader_module="mlflow.spark",
        data=_SPARK_MODEL_PATH_SUB,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        code=code_dir_subpath,
    )
    if size := get_total_file_size(dst_dir):
        mlflow_model.model_size_bytes = size
    mlflow_model.save(os.path.join(dst_dir, MLMODEL_FILE_NAME))

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements(is_spark_connect_model)
            if remote_model_path:
                _logger.info(
                    "Inferring pip requirements by reloading the logged model from the databricks "
                    "artifact repository, which can be time-consuming. To speed up, explicitly "
                    "specify the conda_env or pip_requirements when calling log_model()."
                )
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                remote_model_path or dst_dir,
                FLAVOR_NAME,
                fallback=default_reqs,
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(dst_dir, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(dst_dir, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(dst_dir, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(dst_dir, _PYTHON_ENV_FILE_NAME))


def _validate_model(spark_model):
    from pyspark.ml import Model as PySparkModel
    from pyspark.ml import Transformer as PySparkTransformer
    from pyspark.ml.util import MLReadable, MLWritable

    if _is_spark_connect_model(spark_model):
        return

    if (
        (
            not isinstance(spark_model, PySparkModel)
            and not isinstance(spark_model, PySparkTransformer)
        )
        or not isinstance(spark_model, MLReadable)
        or not isinstance(spark_model, MLWritable)
    ):
        raise MlflowException(
            "Cannot serialize this model. MLflow can only save descendants of pyspark.ml.Model "
            "or pyspark.ml.Transformer that implement MLWritable and MLReadable.",
            INVALID_PARAMETER_VALUE,
        )


def _is_spark_connect_model(spark_model):
    """
    Return whether the given model is a Spark Connect ML model.
    """
    try:
        from pyspark.ml.connect import Model as ConnectModel

        return isinstance(spark_model, ConnectModel)
    except ImportError:
        # pyspark < 3.5 does not support Spark connect ML model
        return False


def _is_uc_volume_uri(url):
    parsed_url = urlparse(url)
    return parsed_url.scheme in ["", "dbfs"] and parsed_url.path.startswith("/Volumes")
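
# Illustrative behavior of _is_uc_volume_uri (paths are placeholders):
#   _is_uc_volume_uri("/Volumes/catalog/schema/vol/tmp")      -> True
#   _is_uc_volume_uri("dbfs:/Volumes/catalog/schema/vol/tmp") -> True
#   _is_uc_volume_uri("dbfs:/my-directory")                   -> False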


def _check_databricks_uc_volume_tmpdir_availability(dfs_tmpdir):
    if (
        databricks_utils.is_in_databricks_serverless_runtime()
        or databricks_utils.is_in_databricks_shared_cluster_runtime()
    ):
        if not dfs_tmpdir or not _is_uc_volume_uri(dfs_tmpdir):
            raise MlflowException(
                "UC volume path must be provided to save, log or load SparkML models "
                "in Databricks shared or serverless clusters. "
                "Specify environment variable 'MLFLOW_DFS_TMP' "
                "or 'dfs_tmpdir' argument that uses a UC volume path starting with '/Volumes/...' "
                "when saving, logging or loading a model."
            )


@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="pyspark"))
def save_model(
    spark_model,
    path,
    mlflow_model=None,
    conda_env=None,
    code_paths=None,
    dfs_tmpdir=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    metadata=None,
):
    """
    Save a Spark MLlib Model to a local path.

    By default, this function saves models using the Spark MLlib persistence mechanism.

    Args:
        spark_model: Spark model to be saved - MLflow can only save descendants of
            pyspark.ml.Model or pyspark.ml.Transformer which implement
            MLReadable and MLWritable.
        path: Local path where the model is to be saved.
        mlflow_model: MLflow model config this flavor is being added to.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
            filesystem if running in local mode. The model is written to this
            destination and then copied to the requested local path. This is necessary
            as Spark ML models read from and write to DFS if running on a cluster. All
            temporary files created on the DFS are removed if this operation
            completes successfully. Defaults to ``/tmp/mlflow``.
        signature: See the documentation of the ``signature`` argument in
            :py:func:`mlflow.spark.log_model`.
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        metadata: {{ metadata }}

    .. code-block:: python
        :caption: Example

        from mlflow import spark
        from pyspark.ml.pipeline import PipelineModel

        # your pyspark.ml.pipeline.PipelineModel type
        model = ...
        mlflow.spark.save_model(model, "spark-model")
    """
    _validate_model(spark_model)
    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    from pyspark.ml import PipelineModel

    from mlflow.utils._spark_utils import _get_active_spark_session

    is_spark_connect_model = _is_spark_connect_model(spark_model)

    if not is_spark_connect_model and not isinstance(spark_model, PipelineModel):
        spark_model = PipelineModel([spark_model])
    if mlflow_model is None:
        mlflow_model = Model()
    if metadata is not None:
        mlflow_model.metadata = metadata

    # For automatic signature inference, we use an inline implementation rather than the
    # `_infer_signature_from_input_example` API because we need to convert model predictions from a
    # list into a Pandas series for signature inference.
    if signature is None and input_example is not None:
        input_ex = _Example(input_example).inference_data
        try:
            spark = _get_active_spark_session()
            if spark is not None:
                input_example_spark_df = spark.createDataFrame(input_ex)
                # `_infer_spark_model_signature` mutates the model. Copy the model to preserve the
                # original model.
                try:
                    spark_model = spark_model.copy()
                except Exception:
                    _logger.debug(
                        "Failed to copy the model, using the original model.", exc_info=True
                    )
                signature = mlflow.pyspark.ml._infer_spark_model_signature(
                    spark_model, input_example_spark_df
                )
        except Exception as e:
            if environment_variables._MLFLOW_TESTING.get():
                raise
            _logger.warning(_LOG_MODEL_INFER_SIGNATURE_WARNING_TEMPLATE, repr(e))
            _logger.debug("", exc_info=True)
    elif signature is False:
        signature = None

    sparkml_data_path = os.path.abspath(os.path.join(path, _SPARK_MODEL_PATH_SUB))

    if is_spark_connect_model:
        spark_model.saveToLocal(sparkml_data_path)
    else:
        # Spark ML stores the model on DFS if running on a cluster
        # Save it to a DFS temp dir first and copy it to local path
        if dfs_tmpdir is None:
            dfs_tmpdir = MLFLOW_DFS_TMP.get()

        _check_databricks_uc_volume_tmpdir_availability(dfs_tmpdir)
        tmp_path = generate_tmp_dfs_path(dfs_tmpdir)
        spark_model.save(tmp_path)

        if databricks_utils.is_in_databricks_runtime() and _is_uc_volume_uri(tmp_path):
            # The temp DFS path is a UC volume path.
            # Use UC volume fuse mount to read data.
            tmp_path_fuse = urlparse(tmp_path).path
            shutil.move(src=tmp_path_fuse, dst=sparkml_data_path)
        else:
            # We're copying the Spark model from DBFS to the local filesystem if (a) the temporary
            # DFS URI we saved the Spark model to is a DBFS URI ("dbfs:/my-directory"), or (b) if
            # we're running on a Databricks cluster and the URI is schemeless (e.g. looks like a
            # filesystem absolute path like "/my-directory")
            copying_from_dbfs = is_valid_dbfs_uri(tmp_path) or (
                databricks_utils.is_in_cluster() and posixpath.abspath(tmp_path) == tmp_path
            )
            if copying_from_dbfs and databricks_utils.is_dbfs_fuse_available():
                tmp_path_fuse = dbfs_hdfs_uri_to_fuse_path(tmp_path)
                shutil.move(src=tmp_path_fuse, dst=sparkml_data_path)
            else:
                _HadoopFileSystem.copy_to_local_file(tmp_path, sparkml_data_path, remove_src=True)

    _save_model_metadata(
        dst_dir=path,
        spark_model=spark_model,
        mlflow_model=mlflow_model,
        conda_env=conda_env,
        code_paths=code_paths,
        signature=signature,
        input_example=input_example,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
    )


def _load_model_databricks_dbfs(dfs_tmpdir, local_model_path):
    from pyspark.ml.pipeline import PipelineModel

    # Spark ML expects the model to be stored on DFS
    # Copy the model to a temp DFS location first. We cannot delete this file, as
    # Spark may read from it at any point.
    fuse_dfs_tmpdir = dbfs_hdfs_uri_to_fuse_path(dfs_tmpdir)
    os.makedirs(fuse_dfs_tmpdir)
    # Workaround for inability to use shutil.copytree with DBFS FUSE due to permission-denied
    # errors on passthrough-enabled clusters when attempting to copy permission bits for directories
    shutil_copytree_without_file_permissions(src_dir=local_model_path, dst_dir=fuse_dfs_tmpdir)
    return PipelineModel.load(dfs_tmpdir)


def _load_model_databricks_uc_volume(dfs_tmpdir, local_model_path):
    from pyspark.ml.pipeline import PipelineModel

    # Copy the model to a temp DFS location first. We cannot delete this file, as
    # Spark may read from it at any point.
    fuse_dfs_tmpdir = urlparse(dfs_tmpdir).path
    shutil.copytree(src=local_model_path, dst=fuse_dfs_tmpdir)
    return PipelineModel.load(dfs_tmpdir)


def _load_model(model_uri, dfs_tmpdir_base=None, local_model_path=None):
    from pyspark.ml.pipeline import PipelineModel

    dfs_tmpdir = generate_tmp_dfs_path(dfs_tmpdir_base or MLFLOW_DFS_TMP.get())

    _check_databricks_uc_volume_tmpdir_availability(dfs_tmpdir)
    if (
        databricks_utils.is_in_databricks_serverless_runtime()
        or databricks_utils.is_in_databricks_shared_cluster_runtime()
    ):
        return _load_model_databricks_uc_volume(
            dfs_tmpdir, local_model_path or _download_artifact_from_uri(model_uri)
        )
    if databricks_utils.is_in_cluster() and databricks_utils.is_dbfs_fuse_available():
        return _load_model_databricks_dbfs(
            dfs_tmpdir, local_model_path or _download_artifact_from_uri(model_uri)
        )
    model_uri = _HadoopFileSystem.maybe_copy_from_uri(model_uri, dfs_tmpdir, local_model_path)
    return PipelineModel.load(model_uri)


def _load_spark_connect_model(model_class, local_path):
    return _get_class_from_string(model_class).loadFromLocal(local_path)


def load_model(model_uri, dfs_tmpdir=None, dst_path=None):
    """
    Load the Spark MLlib model from the path.

    Args:
        model_uri: The location, in URI format, of the MLflow model, for example:

            - ``/Users/me/path/to/local/model``
            - ``relative/path/to/local/model``
            - ``s3://my_bucket/path/to/model``
            - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
            - ``models:/<model_name>/<model_version>``
            - ``models:/<model_name>/<stage>``

            For more information about supported URI schemes, see
            `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
            artifact-locations>`_.
        dfs_tmpdir: Temporary directory path on Distributed (Hadoop) File System (DFS) or local
            filesystem if running in local mode. The model is loaded from this
            destination. Defaults to ``/tmp/mlflow``.
        dst_path: The local filesystem path to which to download the model artifact.
            This directory must already exist. If unspecified, a local output
            path will be created.

    Returns:
        pyspark.ml.pipeline.PipelineModel

    .. code-block:: python
        :caption: Example

        from mlflow import spark

        model = mlflow.spark.load_model("spark-model")
        # Prepare test documents, which are unlabeled (id, text) tuples.
        test = spark.createDataFrame(
            [(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
            ["id", "text"],
        )
        # Make predictions on test documents
        prediction = model.transform(test)
    """
    # This MUST be called prior to appending the model flavor to `model_uri` in order
    # for `artifact_path` to take on the correct value for model loading via mlflowdbfs.
    root_uri, artifact_path = _get_root_uri_and_artifact_path(model_uri)

    local_mlflow_model_path = _download_artifact_from_uri(
        artifact_uri=model_uri, output_path=dst_path
    )
    flavor_conf = Model.load(local_mlflow_model_path).flavors[FLAVOR_NAME]
    _add_code_from_conf_to_system_path(local_mlflow_model_path, flavor_conf)

    model_class = flavor_conf.get("model_class")
    if model_class is not None and model_class.startswith("pyspark.ml.connect."):
        spark_model_local_path = os.path.join(local_mlflow_model_path, flavor_conf["model_data"])
        return _load_spark_connect_model(model_class, spark_model_local_path)

    if _should_use_mlflowdbfs(model_uri) and (
        run_id := DatabricksArtifactRepository._extract_run_id(model_uri)
    ):
        from pyspark.ml.pipeline import PipelineModel

        mlflowdbfs_path = _mlflowdbfs_path(run_id, artifact_path)
        with databricks_utils.MlflowCredentialContext(
            get_databricks_profile_uri_from_artifact_uri(root_uri)
        ):
            return PipelineModel.load(mlflowdbfs_path)

    sparkml_model_uri = append_to_uri_path(model_uri, flavor_conf["model_data"])
    local_sparkml_model_path = os.path.join(local_mlflow_model_path, flavor_conf["model_data"])
    return _load_model(
        model_uri=sparkml_model_uri,
        dfs_tmpdir_base=dfs_tmpdir,
        local_model_path=local_sparkml_model_path,
    )


def _load_pyfunc(path):
    """
    Load PyFunc implementation. Called by ``pyfunc.load_model``.

    Args:
        path: Local filesystem path to the MLflow Model with the ``spark`` flavor.
    """
    from mlflow.utils._spark_utils import (
        _create_local_spark_session_for_loading_spark_model,
        _get_active_spark_session,
    )

    model_meta_path = os.path.join(os.path.dirname(path), MLMODEL_FILE_NAME)
    model_meta = Model.load(model_meta_path)

    model_class = model_meta.flavors[FLAVOR_NAME].get("model_class")
    if model_class is not None and model_class.startswith("pyspark.ml.connect."):
        # Note:
        # Spark connect ML models don't require a spark session for running inference.
        spark = None
        spark_model = _load_spark_connect_model(model_class, path)

    else:
        # NOTE: The `_create_local_spark_session_for_loading_spark_model()` call below may change
        # settings of the active session which we do not intend to do here.
        # In particular, setting master to local[1] can break distributed clusters.
        # To avoid this problem, we explicitly check for an active session. This is not ideal but
        # there is no good workaround at the moment.
        spark = _get_active_spark_session()
        if spark is None:
            # NB: If there is no existing Spark context, create a new local one.
            # NB: We're disabling caching on the new context since we do not need it and we want to
            # avoid overwriting cache of underlying Spark cluster when executed on a Spark Worker
            # (e.g. as part of spark_udf).
            spark = _create_local_spark_session_for_loading_spark_model()

        spark_model = _load_model(model_uri=path)

    return _PyFuncModelWrapper(spark, spark_model, signature=model_meta.signature)


def _find_and_set_features_col_as_vector_if_needed(spark_df, spark_model):
    """
    Finds the `featuresCol` column in spark_model and
    then tries to cast that column to `vector` type.
    This method is a no-op if the `featuresCol` is already of type `vector`
    or if it can't be cast to `vector` type.
    Note:
        If a Spark ML pipeline contains a single Estimator stage, it requires
        the input dataframe to contain a features column of vector type.
        However, autologging for pyspark ML casts the vector column to array<double> type
        for parity with the pandas DataFrame. The following fix is therefore required: it
        transforms the features column back to vector type so that the pipeline stages can
        work correctly. A valid scenario is when the auto-logged input example is directly
        used for prediction, which would otherwise fail without this transformation.

    Args:
        spark_df: Input dataframe that contains `featuresCol`
        spark_model: A pipeline model or a single transformer that contains `featuresCol` param

    Returns:
        A spark dataframe that contains features column of `vector` type.
    """
    from pyspark.ml.linalg import Vectors, VectorUDT
    from pyspark.sql import types as t
    from pyspark.sql.functions import udf

    def _find_stage_with_features_col(stage):
        if stage.hasParam("featuresCol"):

            def _array_to_vector(input_array):
                return Vectors.dense(input_array)

            array_to_vector_udf = udf(f=_array_to_vector, returnType=VectorUDT())
            features_col_name = stage.extractParamMap().get(stage.featuresCol)
            features_col_type = [
                _field
                for _field in spark_df.schema.fields
                if _field.name == features_col_name
                and _field.dataType
                in [t.ArrayType(t.DoubleType(), True), t.ArrayType(t.DoubleType(), False)]
            ]
            if len(features_col_type) == 1:
                return spark_df.withColumn(
                    features_col_name, array_to_vector_udf(features_col_name)
                )
        return spark_df

    if hasattr(spark_model, "stages"):
        for stage in reversed(spark_model.stages):
            return _find_stage_with_features_col(stage)
    return _find_stage_with_features_col(spark_model)


class _PyFuncModelWrapper:
    """
    Wrapper around Spark MLlib PipelineModel providing interface for scoring pandas DataFrame.
    """

    def __init__(self, spark, spark_model, signature):
        self.spark = spark
        self.spark_model = spark_model
        self.signature = signature

    def get_raw_model(self):
        """
        Returns the underlying model.
        """
        return self.spark_model

    def predict(
        self,
        pandas_df,
        params: Optional[dict[str, Any]] = None,
    ):
        """
        Generate predictions given input data in a pandas DataFrame.

        Args:
            pandas_df: pandas DataFrame containing input data.
            params: Additional parameters to pass to the model for inference.

        Returns:
            List with model predictions.
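
        A minimal calling sketch (illustrative; in practice this wrapper is obtained by
        loading a ``spark``-flavor model via ``mlflow.pyfunc.load_model``, whose
        ``predict`` delegates here):

        .. code-block:: python

            import pandas as pd

            preds = wrapper.predict(pd.DataFrame({"features": [[1.0, 2.0]]}))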
        """
        if _is_spark_connect_model(self.spark_model):
            # Spark connect ML models directly append the prediction result column to the
            # input pandas dataframe. To keep the input dataframe intact, make a copy first.
            # TODO: apache/spark master has made a change to do shallow copy before
            # calling `spark_model.transform`, so once spark 4.0 releases, we can
            # remove this line.
            pandas_df = pandas_df.copy(deep=False)
            # Assuming the model output column name is "prediction".
            # Spark models use "prediction" as the default model inference output column name.
            return self.spark_model.transform(pandas_df)["prediction"]

        # Convert List[np.float64] / np.array[np.float64] type to List[float] type,
        # otherwise it will break `spark.createDataFrame` column type inference.
        if self.signature and self.signature.inputs:
            for col_spec in self.signature.inputs.inputs:
                if isinstance(col_spec.type, SparkMLVector):
                    col_name = col_spec.name or pandas_df.columns[0]

                    pandas_df[col_name] = pandas_df[col_name].map(
                        lambda array: [float(elem) for elem in array]
                    )

        spark_df = self.spark.createDataFrame(pandas_df)

        # Convert Array[Double] columns to Spark ML vector type according to the signature
        if self.signature and self.signature.inputs:
            for col_spec in self.signature.inputs.inputs:
                if isinstance(col_spec.type, SparkMLVector):
                    from pyspark.ml.functions import array_to_vector

                    col_name = col_spec.name or spark_df.columns[0]
                    spark_df = spark_df.withColumn(col_name, array_to_vector(col_name))

        # For the case of no signature, or a signature logged by an old version of MLflow
        # that does not support the Spark ML vector type, automatically infer vector type
        # input columns and do the conversion using the
        # `_find_and_set_features_col_as_vector_if_needed` utility function.
        spark_df = _find_and_set_features_col_as_vector_if_needed(spark_df, self.spark_model)

        prediction_column = mlflow.pyspark.ml._check_or_set_model_prediction_column(
            self.spark_model, spark_df
        )
        prediction_df = self.spark_model.transform(spark_df).select(prediction_column)

        # If the signature output schema exists and contains vector type columns,
        # convert the Spark ML vector type columns to Array[Double]; otherwise it would
        # break enforce_schema checking.
        if self.signature and self.signature.outputs:
            for col_spec in self.signature.outputs.inputs:
                if isinstance(col_spec.type, SparkMLVector):
                    from pyspark.ml.functions import vector_to_array

                    col_name = col_spec.name or prediction_df.columns[0]
                    prediction_df = prediction_df.withColumn(col_name, vector_to_array(col_name))
        return [x.prediction for x in prediction_df.collect()]


@autologging_integration(FLAVOR_NAME)
def autolog(disable=False, silent=False):
    """
    Enables (or disables) and configures logging of Spark datasource paths, versions
    (if applicable), and formats when they are read. This method is not threadsafe and assumes a
    `SparkSession
    <https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.html>`_
    already exists with the
    `mlflow-spark JAR
    <https://www.mlflow.org/docs/latest/tracking.html#spark>`_
    attached. It should be called on the Spark driver, not on the executors (i.e. do not call
    this method within a function parallelized by Spark).
    The mlflow-spark JAR used must match the Scala version of Spark. Please see the
    `Maven Repository
    <https://mvnrepository.com/artifact/org.mlflow/mlflow-spark>`_
    for available versions. This API requires Spark 3.0 or above.

    Datasource information is cached in memory and logged to all subsequent MLflow runs,
    including the active MLflow run (if one exists when the data is read). Note that autologging of
    Spark ML (MLlib) models is not currently supported via this API. Datasource autologging is
    best-effort, meaning that if Spark is under heavy load or MLflow logging fails for any reason
    (e.g., if the MLflow server is unavailable), logging may be dropped.

    For any unexpected issues with autologging, check Spark driver and executor logs in addition
    to stderr & stdout generated from your MLflow code - datasource information is pulled from
    Spark, so logs relevant to debugging may show up amongst the Spark logs.

    .. Note:: Spark datasource autologging only supports logging to MLflow runs in a single thread

    .. code-block:: python
        :caption: Example

        import mlflow.spark
        import os
        import shutil
        from pyspark.sql import SparkSession

        # Create and persist some dummy data
        # Note: the 2.12 in 'org.mlflow:mlflow-spark_2.12:2.16.2' below indicates the Scala
        # version, please match this with that of Spark. The 2.16.2 indicates the mlflow version.
        # Note: On environments like Databricks with pre-created SparkSessions,
        # ensure the org.mlflow:mlflow-spark_2.12:2.16.2 is attached as a library to
        # your cluster
        spark = (
            SparkSession.builder.config(
                "spark.jars.packages",
                "org.mlflow:mlflow-spark_2.12:2.16.2",
            )
            .master("local[*]")
            .getOrCreate()
        )
        df = spark.createDataFrame(
            [(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
            ["id", "text"],
        )
        import tempfile

        tempdir = tempfile.mkdtemp()
        df.write.csv(os.path.join(tempdir, "my-data-path"), header=True)
        # Enable Spark datasource autologging.
        mlflow.spark.autolog()
        loaded_df = spark.read.csv(
            os.path.join(tempdir, "my-data-path"), header=True, inferSchema=True
        )
        # Call toPandas() to trigger a read of the Spark datasource. Datasource info
        # (path and format) is logged to the current active run, or the
        # next-created MLflow run if no run is currently active
        with mlflow.start_run() as active_run:
            pandas_df = loaded_df.toPandas()

    Args:
        disable: If ``True``, disables the Spark datasource autologging integration.
            If ``False``, enables the Spark datasource autologging integration.
        silent: If ``True``, suppress all event logs and warnings from MLflow during Spark
            datasource autologging. If ``False``, show all events and warnings during Spark
            datasource autologging.
    """
    from pyspark import __version__ as pyspark_version
    from pyspark.sql import SparkSession

    from mlflow.spark.autologging import (
        _listen_for_spark_activity,
        _stop_listen_for_spark_activity,
    )
    from mlflow.utils import databricks_utils
    from mlflow.utils._spark_utils import _get_active_spark_session

    if (
        databricks_utils.is_in_databricks_serverless_runtime()
        or databricks_utils.is_in_databricks_shared_cluster_runtime()
    ):
        if disable:
            return
        raise MlflowException(
            "MLflow Spark dataset autologging is not supported on Databricks shared clusters "
            "or Databricks serverless clusters."
        )

    # Check if the environment variable PYSPARK_PIN_THREAD is set to false.
    # The "pin thread" concept was introduced in Pyspark 3.0.0 and has defaulted to true
    # since Pyspark 3.2.0 (https://issues.apache.org/jira/browse/SPARK-35303). When pin thread
    # is enabled, Pyspark manages Python and JVM threads 1:1, meaning that when one thread
    # is terminated, the corresponding thread on the other side is terminated as well.
    # However, this causes an issue in spark autologging, as our event listener thread may be
    # terminated before receiving the datasource event.
    # Hence, we have to disable it and decouple the thread management between Python and JVM.
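    # Illustrative: the variable must be set in the environment before Python / the JVM
    # starts, e.g. (assuming a POSIX shell):
    #   export PYSPARK_PIN_THREAD=false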
    if (
        Version(pyspark_version) >= Version("3.2.0")
        and os.environ.get("PYSPARK_PIN_THREAD", "").lower() != "false"
    ):
        _logger.warning(
            "With Pyspark >= 3.2, PYSPARK_PIN_THREAD environment variable must be set to false "
            "for Spark datasource autologging to work."
        )

    def __init__(original, self, *args, **kwargs):
        original(self, *args, **kwargs)

        _listen_for_spark_activity(self._sc)

    safe_patch(FLAVOR_NAME, SparkSession, "__init__", __init__, manage_run=False)

    def patched_session_stop(original, self, *args, **kwargs):
        _stop_listen_for_spark_activity(self.sparkContext)
        original(self, *args, **kwargs)

    safe_patch(FLAVOR_NAME, SparkSession, "stop", patched_session_stop, manage_run=False)

    active_session = _get_active_spark_session()
    if active_session is not None:
        # We know SparkContext exists here already, so get it
        sc = active_session.sparkContext

        if disable:
            _stop_listen_for_spark_activity(sc)
        else:
            _listen_for_spark_activity(sc)