snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published in their respective public registries.
Files changed (62)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +54 -42
  5. snowflake/ml/_internal/utils/service_logger.py +105 -3
  6. snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
  7. snowflake/ml/data/data_connector.py +13 -2
  8. snowflake/ml/data/data_ingestor.py +8 -0
  9. snowflake/ml/data/torch_utils.py +1 -1
  10. snowflake/ml/dataset/dataset.py +2 -1
  11. snowflake/ml/dataset/dataset_reader.py +14 -4
  12. snowflake/ml/experiment/__init__.py +3 -0
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/callback.py +121 -0
  20. snowflake/ml/experiment/experiment_tracking.py +319 -0
  21. snowflake/ml/jobs/_utils/constants.py +15 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +156 -54
  23. snowflake/ml/jobs/_utils/query_helper.py +16 -5
  24. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  25. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
  26. snowflake/ml/jobs/_utils/spec_utils.py +23 -8
  27. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  28. snowflake/ml/jobs/_utils/types.py +64 -4
  29. snowflake/ml/jobs/job.py +70 -75
  30. snowflake/ml/jobs/manager.py +59 -31
  31. snowflake/ml/lineage/lineage_node.py +2 -2
  32. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  33. snowflake/ml/model/_client/ops/service_ops.py +336 -137
  34. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  35. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
  36. snowflake/ml/model/_client/sql/service.py +1 -38
  37. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  38. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  41. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  45. snowflake/ml/model/_signatures/utils.py +4 -0
  46. snowflake/ml/model/event_handler.py +117 -0
  47. snowflake/ml/model/model_signature.py +11 -9
  48. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  49. snowflake/ml/modeling/framework/base.py +1 -1
  50. snowflake/ml/modeling/metrics/classification.py +14 -14
  51. snowflake/ml/modeling/metrics/correlation.py +19 -8
  52. snowflake/ml/modeling/metrics/ranking.py +6 -6
  53. snowflake/ml/modeling/metrics/regression.py +9 -9
  54. snowflake/ml/monitoring/explain_visualize.py +12 -5
  55. snowflake/ml/registry/_manager/model_manager.py +32 -15
  56. snowflake/ml/registry/registry.py +48 -80
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
  59. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
  60. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
  61. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/env_utils.py
@@ -337,13 +337,54 @@ def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> requirements.Requirement:
     Returns:
         A requirements.Requirement object with supported ops only
     """
+
+    if req.name == "numpy":
+        import numpy as np
+
+        package_specifiers = get_numpy_specifiers(req, version.Version(np.__version__).major)
+    else:
+        package_specifiers = [spec for spec in req.specifier if spec.operator in _SUPPORTED_PACKAGE_SPEC_OPS]
+
     new_req = copy.deepcopy(req)
-    new_req.specifier = specifiers.SpecifierSet(
-        specifiers=",".join([str(spec) for spec in req.specifier if spec.operator in _SUPPORTED_PACKAGE_SPEC_OPS])
-    )
+    new_req.specifier = specifiers.SpecifierSet(specifiers=",".join([str(spec) for spec in package_specifiers]))
     return new_req
 
 
+def get_numpy_specifiers(
+    req: requirements.Requirement,
+    client_numpy_major_version: int,
+) -> list[specifiers.Specifier]:
+    """Get the package spec with supported ops only including ==, >=, <=, > and < based on the client numpy
+    major version.
+
+    Args:
+        req: A requirements.Requirement object showing the requirement.
+        client_numpy_major_version: The major version of numpy to be used.
+
+    Returns:
+        A list of specifiers with supported ops only
+    """
+    req_specifiers = []
+    for org_spec in req.specifier:
+        # check specifier that provides upper bound
+        if org_spec.operator in ["<", "<="]:
+            client_version = version.Version(str(client_numpy_major_version))
+            org_spec_version = version.Version(org_spec.version)
+            # check if the client's numpy major version is less than the specifier's upper bound
+            # if so, pin to max possible client major version
+            if client_version.major < org_spec_version.major:
+                modified_spec = specifiers.Specifier(f"<{client_version.major + 1}")
+                req_specifiers.append(modified_spec)
+            else:
+                # use the original specifier
+                req_specifiers.append(org_spec)
+        else:
+            # use the original specifier
+            req_specifiers.append(org_spec)
+
+    return req_specifiers
+
+
 def _relax_specifier_set(
     specifier_set: specifiers.SpecifierSet, strategy: relax_version_strategy.RelaxVersionStrategy
 ) -> specifiers.SpecifierSet:
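The new `get_numpy_specifiers` caps any `<` or `<=` bound whose major version exceeds the client's installed numpy major line. A standalone, runnable sketch of that behavior (the helper name below is local to this example, not the package API):

```python
# Sketch of the upper-bound pinning logic; names are local to this example.
from packaging import requirements, specifiers, version


def numpy_specifiers_for_client(
    req: requirements.Requirement, client_major: int
) -> list[specifiers.Specifier]:
    result = []
    for spec in req.specifier:
        # Tighten any upper bound that sits above the client's numpy major line.
        if spec.operator in ("<", "<=") and client_major < version.Version(spec.version).major:
            result.append(specifiers.Specifier(f"<{client_major + 1}"))
        else:
            result.append(spec)
    return result


req = requirements.Requirement("numpy>=1.22,<3")
print(sorted(str(s) for s in numpy_specifiers_for_client(req, client_major=1)))
# On a numpy 1.x client, "<3" is pinned down to "<2": ['<2', '>=1.22']
```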
snowflake/ml/_internal/platform_capabilities.py
@@ -3,7 +3,9 @@ from contextlib import contextmanager
 from typing import Any, Optional
 
 from absl import logging
+from packaging import version
 
+from snowflake.ml import version as snowml_version
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import query_result_checker
 from snowflake.snowpark import (
@@ -12,7 +14,7 @@ from snowflake.snowpark import (
 )
 
 LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
-INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC"
+INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
 
 
 class PlatformCapabilities:
@@ -67,7 +69,7 @@ class PlatformCapabilities:
         cls.clear_mock_features()
 
     def is_inlined_deployment_spec_enabled(self) -> bool:
-        return self._get_bool_feature(INLINE_DEPLOYMENT_SPEC_PARAMETER, False)
+        return self._is_version_feature_enabled(INLINE_DEPLOYMENT_SPEC_PARAMETER)
 
     def is_live_commit_enabled(self) -> bool:
         return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
@@ -126,3 +128,51 @@ class PlatformCapabilities:
         else:
             raise ValueError(f"Invalid boolean string: {value} for feature {feature_name}")
         raise ValueError(f"Invalid boolean feature value: {value} for feature {feature_name}")
+
+    def _get_version_feature(self, feature_name: str) -> version.Version:
+        """Get a version feature value, returning a large version number on failure or missing feature.
+
+        Args:
+            feature_name: The name of the feature to retrieve.
+
+        Returns:
+            version.Version: The parsed version, or a large version number (999.999.999) if parsing fails
+                or the feature is missing.
+        """
+        # Large version number to use as fallback
+        large_version = version.Version("999.999.999")
+
+        value = self.features.get(feature_name)
+        if value is None:
+            logging.debug(f"Feature {feature_name} not found, returning large version number")
+            return large_version
+
+        try:
+            # Convert to string if it's not already
+            version_str = str(value)
+            return version.Version(version_str)
+        except (version.InvalidVersion, ValueError, TypeError) as e:
+            logging.debug(
+                f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
+                f"Returning large version number"
+            )
+            return large_version
+
+    def _is_version_feature_enabled(self, feature_name: str) -> bool:
+        """Check if the current package version is greater than or equal to the version feature.
+
+        Args:
+            feature_name: The name of the version feature to compare against.
+
+        Returns:
+            bool: True if current package version >= feature version, False otherwise.
+        """
+        current_version = version.Version(snowml_version.VERSION)
+        feature_version = self._get_version_feature(feature_name)
+
+        result = current_version >= feature_version
+        logging.debug(
+            f"Version comparison for feature {feature_name}: "
+            f"current={current_version}, feature={feature_version}, enabled={result}"
+        )
+        return result
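The effect of the new gate: the server advertises a minimum client version per feature, and the flag is on only when the installed client is at or past it; a missing or unparsable value falls back to `999.999.999`, which keeps the feature off. A reduced, runnable sketch (local names, not the package API):

```python
# Sketch of _get_version_feature / _is_version_feature_enabled; local names.
from packaging import version

_NEVER = version.Version("999.999.999")  # fallback: feature effectively disabled


def feature_enabled(client_version: str, advertised: object) -> bool:
    if advertised is None:
        gate = _NEVER
    else:
        try:
            gate = version.Version(str(advertised))
        except version.InvalidVersion:
            gate = _NEVER
    return version.Version(client_version) >= gate


print(feature_enabled("1.9.2", "1.9.1"))  # True: client at or past the gate
print(feature_enabled("1.9.0", "1.9.1"))  # False: client predates the gate
print(feature_enabled("1.9.2", None))     # False: parameter not advertised
```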
snowflake/ml/_internal/type_utils.py
@@ -66,4 +66,4 @@ class LazyType(Generic[T]):
         return False
 
 
-LiteralNDArrayType = Union[npt.NDArray[np.int_], npt.NDArray[np.float_], npt.NDArray[np.str_], npt.NDArray[np.bool_]]
+LiteralNDArrayType = Union[npt.NDArray[np.int_], npt.NDArray[np.float64], npt.NDArray[np.str_], npt.NDArray[np.bool_]]
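Background for this one-token change: `np.float_` was an alias of `np.float64` and was removed in NumPy 2.0, so the annotation now names the concrete dtype and imports cleanly on both major lines. A quick check:

```python
# np.float64 is importable on NumPy 1.x and 2.x; np.float_ is gone in 2.x.
import numpy as np
import numpy.typing as npt

arr: npt.NDArray[np.float64] = np.zeros(3, dtype=np.float64)
print(arr.dtype)  # float64
```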
snowflake/ml/_internal/utils/mixins.py
@@ -1,7 +1,19 @@
 from typing import Any, Optional
 
 from snowflake.ml._internal.utils import identifier
-from snowflake.snowpark import session
+from snowflake.snowpark import session as snowpark_session
+
+_SESSION_KEY = "_session"
+_SESSION_ACCOUNT_KEY = "session$account"
+_SESSION_ROLE_KEY = "session$role"
+_SESSION_DATABASE_KEY = "session$database"
+_SESSION_SCHEMA_KEY = "session$schema"
+
+
+def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
+    saved_resolved = identifier.resolve_identifier(saved) if saved is not None else saved
+    current_resolved = identifier.resolve_identifier(current) if current is not None else current
+    return saved_resolved == current_resolved
 
 
 class SerializableSessionMixin:
@@ -9,53 +21,53 @@ class SerializableSessionMixin:
 
     def __getstate__(self) -> dict[str, Any]:
         """Customize pickling to exclude non-serializable session and related components."""
-        state = self.__dict__.copy()
+        parent_state = (
+            super().__getstate__()  # type: ignore[misc] # object.__getstate__ appears in 3.11
+            if hasattr(super(), "__getstate__")
+            else self.__dict__
+        )
+        state = dict(parent_state)  # Create a copy so we can safely modify the state
 
         # Save session metadata for validation during unpickling
-        if hasattr(self, "_session") and self._session is not None:
-            try:
-                state["__session-account__"] = self._session.get_current_account()
-                state["__session-role__"] = self._session.get_current_role()
-                state["__session-database__"] = self._session.get_current_database()
-                state["__session-schema__"] = self._session.get_current_schema()
-            except Exception:
-                pass
-
-        state["_session"] = None
+        session = state.pop(_SESSION_KEY, None)
+        if session is not None:
+            state[_SESSION_ACCOUNT_KEY] = session.get_current_account()
+            state[_SESSION_ROLE_KEY] = session.get_current_role()
+            state[_SESSION_DATABASE_KEY] = session.get_current_database()
+            state[_SESSION_SCHEMA_KEY] = session.get_current_schema()
+
         return state
 
     def __setstate__(self, state: dict[str, Any]) -> None:
         """Restore session from context during unpickling."""
-        saved_account = state.pop("__session-account__", None)
-        saved_role = state.pop("__session-role__", None)
-        saved_database = state.pop("__session-database__", None)
-        saved_schema = state.pop("__session-schema__", None)
-        self.__dict__.update(state)
+        saved_account = state.pop(_SESSION_ACCOUNT_KEY, None)
+        saved_role = state.pop(_SESSION_ROLE_KEY, None)
+        saved_database = state.pop(_SESSION_DATABASE_KEY, None)
+        saved_schema = state.pop(_SESSION_SCHEMA_KEY, None)
+
+        if hasattr(super(), "__setstate__"):
+            super().__setstate__(state)  # type: ignore[misc]
+        else:
+            self.__dict__.update(state)
 
         if saved_account is not None:
+            active_sessions = snowpark_session._get_active_sessions()
+            if len(active_sessions) == 0:
+                raise RuntimeError("No active Snowpark session available. Please create a session.")
 
-            def identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
-                saved_resolved = identifier.resolve_identifier(saved) if saved is not None else saved
-                current_resolved = identifier.resolve_identifier(current) if current is not None else current
-                return saved_resolved == current_resolved
-
-            for active_session in session._get_active_sessions():
-                try:
-                    current_account = active_session.get_current_account()
-                    current_role = active_session.get_current_role()
-                    current_database = active_session.get_current_database()
-                    current_schema = active_session.get_current_schema()
-
-                    if (
-                        identifiers_match(saved_account, current_account)
-                        and identifiers_match(saved_role, current_role)
-                        and identifiers_match(saved_database, current_database)
-                        and identifiers_match(saved_schema, current_schema)
-                    ):
-                        self._session = active_session
-                        return
-                except Exception:
-                    continue
-
-            # No matching session found or no metadata available
-            raise RuntimeError("No active Snowpark session available. Please create a session.")
+            # Best effort match: Find the session with the most matching identifiers
+            setattr(
+                self,
+                _SESSION_KEY,
+                max(
+                    active_sessions,
+                    key=lambda s: sum(
+                        (
+                            _identifiers_match(saved_account, s.get_current_account()),
+                            _identifiers_match(saved_role, s.get_current_role()),
+                            _identifiers_match(saved_database, s.get_current_database()),
+                            _identifiers_match(saved_schema, s.get_current_schema()),
+                        )
+                    ),
+                ),
+            )
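The rewritten `__setstate__` no longer raises when no session matches exactly; it scores every active session by how many saved identifiers agree and keeps the best. The scoring idea, reduced to a runnable sketch with plain dicts standing in for sessions:

```python
# Score each candidate by how many saved identifiers it matches; max() keeps
# the best-scoring one, so an imperfect match is used instead of raising.
saved = {"account": "ACME", "role": "ANALYST", "database": "DB1", "schema": "PUBLIC"}
candidates = [
    {"account": "ACME", "role": "ANALYST", "database": "DB1", "schema": "PUBLIC"},   # 4/4
    {"account": "ACME", "role": "SYSADMIN", "database": "DB2", "schema": "PUBLIC"},  # 2/4
]

best = max(candidates, key=lambda s: sum(s[k] == saved[k] for k in saved))
print(best["role"])  # ANALYST: the fully matching candidate wins
```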
snowflake/ml/_internal/utils/service_logger.py
@@ -1,6 +1,13 @@
 import enum
 import logging
+import os
 import sys
+import tempfile
+import time
+import uuid
+from typing import Optional
+
+import platformdirs
 
 
 class LogColor(enum.Enum):
@@ -10,6 +17,10 @@ class LogColor(enum.Enum):
     YELLOW = "\x1b[33;20m"
     BLUE = "\x1b[34;20m"
     GREEN = "\x1b[32;20m"
+    ORANGE = "\x1b[38;5;214m"
+    BOLD_ORANGE = "\x1b[38;5;214;1m"
+    PURPLE = "\x1b[35;20m"
+    BOLD_PURPLE = "\x1b[35;1m"
 
 
 class CustomFormatter(logging.Formatter):
@@ -53,11 +64,102 @@ class CustomFormatter(logging.Formatter):
         return "\n".join(formatted_lines)
 
 
-def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger:
+def _test_writability(directory: str) -> bool:
+    """Test if a directory is writable by creating and removing a test file."""
+    try:
+        os.makedirs(directory, exist_ok=True)
+        test_file = os.path.join(directory, f".write_test_{uuid.uuid4().hex[:8]}")
+        with open(test_file, "w") as f:
+            f.write("test")
+        os.remove(test_file)
+        return True
+    except OSError:
+        return False
+
+
+def _try_log_location(log_dir: str, operation_id: str) -> Optional[str]:
+    """Try to create a log file in the given directory if it's writable."""
+    if _test_writability(log_dir):
+        return os.path.join(log_dir, f"{operation_id}.log")
+    return None
+
+
+def _get_log_file_path(operation_id: str) -> Optional[str]:
+    """Get platform-independent log file path. Returns None if no writable location found."""
+    # Try locations in order of preference
+    locations = [
+        # Primary: User log directory
+        platformdirs.user_log_dir("snowflake-ml", "Snowflake"),
+        # Fallback 1: System temp directory
+        os.path.join(tempfile.gettempdir(), "snowflake-ml-logs"),
+        # Fallback 2: Current working directory
+        ".",
+    ]
+
+    for location in locations:
+        log_file_path = _try_log_location(location, operation_id)
+        if log_file_path:
+            return log_file_path
+
+    # No writable location found
+    return None
+
+
+def _get_or_create_parent_logger(operation_id: str) -> logging.Logger:
+    """Get or create a parent logger with FileHandler for the operation."""
+    parent_logger_name = f"snowflake_ml_operation_{operation_id}"
+    parent_logger = logging.getLogger(parent_logger_name)
+
+    # Only add handler if it doesn't exist yet
+    if not parent_logger.handlers:
+        log_file_path = _get_log_file_path(operation_id)
+
+        if log_file_path:
+            # Successfully found a writable location
+            try:
+                file_handler = logging.FileHandler(log_file_path)
+                file_handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
+                parent_logger.addHandler(file_handler)
+                parent_logger.setLevel(logging.DEBUG)
+                parent_logger.propagate = False  # Don't propagate to root logger
+
+                # Log the file location
+                parent_logger.warning(f"Operation logs saved to: {log_file_path}")
+            except OSError as e:
+                # Even though we found a path, file creation failed
+                # Fall back to console-only logging
+                parent_logger.setLevel(logging.DEBUG)
+                parent_logger.propagate = False
+                parent_logger.warning(f"Could not create log file at {log_file_path}: {e}. Using console-only logging.")
+        else:
+            # No writable location found, use console-only logging
+            parent_logger.setLevel(logging.DEBUG)
+            parent_logger.propagate = False
+            parent_logger.warning("Filesystem appears to be readonly. Using console-only logging.")
+
+    return parent_logger
+
+
+def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[str] = None) -> logging.Logger:
     logger = logging.getLogger(logger_name)
-    logger.setLevel(logging.INFO)
     handler = logging.StreamHandler(sys.stdout)
-    handler.setLevel(logging.INFO)
     handler.setFormatter(CustomFormatter(info_color))
     logger.addHandler(handler)
+
+    # If operation_id provided, set up parent logger with file handler
+    if operation_id:
+        parent_logger = _get_or_create_parent_logger(operation_id)
+        logger.parent = parent_logger
+        logger.propagate = True
+
     return logger
+
+
+def get_operation_id() -> str:
+    """Generate a unique operation ID."""
+    return f"model_deploy_{uuid.uuid4().hex[:8]}_{int(time.time())}"
+
+
+def get_log_file_location(operation_id: str) -> Optional[str]:
+    """Get the log file path for an operation ID. Returns None if no writable location available."""
+    return _get_log_file_path(operation_id)
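The log-location fallback chain above can be probed directly. A small sketch that only assumes `platformdirs` (newly imported by this module) is installed:

```python
# Print the three candidate log directories in the order the module tries them.
import os
import tempfile

import platformdirs

candidates = [
    platformdirs.user_log_dir("snowflake-ml", "Snowflake"),   # primary
    os.path.join(tempfile.gettempdir(), "snowflake-ml-logs"),  # fallback 1
    ".",                                                       # fallback 2: cwd
]
for path in candidates:
    print(path)
```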
snowflake/ml/data/_internal/arrow_ingestor.py
@@ -2,7 +2,7 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Iterator, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Deque, Iterator, Optional, Sequence, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -10,7 +10,11 @@ import pandas as pd
 import pyarrow as pa
 import pyarrow.dataset as pds
 
+if TYPE_CHECKING:
+    import ray
+
 from snowflake import snowpark
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
 
 _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
@@ -41,7 +45,7 @@ class _RecordBatchesBuffer:
         return popped
 
 
-class ArrowIngestor(data_ingestor.DataIngestor):
+class ArrowIngestor(data_ingestor.DataIngestor, mixins.SerializableSessionMixin):
     """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
 
     def __init__(
@@ -68,8 +72,17 @@ class ArrowIngestor(data_ingestor.DataIngestor):
 
     @classmethod
     def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
+        if session is None:
+            raise ValueError("Session is required")
        return cls(session, sources)
 
+    @classmethod
+    def from_ray_dataset(
+        cls,
+        ray_ds: "ray.data.Dataset",
+    ) -> "ArrowIngestor":
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> list[data_source.DataSource]:
         return self._data_sources

snowflake/ml/data/data_connector.py
@@ -8,7 +8,7 @@ from snowflake import snowpark
 from snowflake.ml._internal import env, telemetry
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
-from snowflake.snowpark import context as sf_context
+from snowflake.snowpark import context as sp_context
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -57,7 +57,7 @@ class DataConnector:
         ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
-        session = session or sf_context.get_active_session()
+        session = session or sp_context.get_active_session()
         source = data_source.DataFrameInfo(query)
         return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
 
@@ -75,6 +75,17 @@ class DataConnector:
         )
         return cls.from_sources(ds._session, [source], ingestor_class=ingestor_class, **kwargs)
 
+    @classmethod
+    def from_ray_dataset(
+        cls: type[DataConnectorType],
+        ray_ds: "ray.data.Dataset",
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any,
+    ) -> DataConnectorType:
+        ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
+        ray_ingestor = ingestor_class.from_ray_dataset(ray_ds=ray_ds)
+        return cls(ray_ingestor, **kwargs)
+
     @classmethod
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,

snowflake/ml/data/data_ingestor.py
@@ -7,6 +7,7 @@ from snowflake.ml.data import data_source
 
 if TYPE_CHECKING:
     import pandas as pd
+    import ray
 
 
 DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
@@ -19,6 +20,13 @@ class DataIngestor(Protocol):
     ) -> DataIngestorType:
         raise NotImplementedError
 
+    @classmethod
+    def from_ray_dataset(
+        cls: type[DataIngestorType],
+        ray_ds: "ray.data.Dataset",
+    ) -> DataIngestorType:
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> list[data_source.DataSource]:
         raise NotImplementedError
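`from_ray_dataset` joins the `DataIngestor` protocol as an alternate constructor that defaults to `NotImplementedError` (the `ArrowIngestor` implementation above is still a stub), and `DataConnector.from_ray_dataset` simply forwards to it on the chosen `ingestor_class`. A minimal sketch of an ingestor opting in, with a hypothetical class and a plain object standing in for `ray.data.Dataset`:

```python
# Hypothetical ingestor implementing the new classmethod; "ray_ds" here is any
# object, standing in for ray.data.Dataset.
from typing import Any


class ToyIngestor:
    def __init__(self, data: Any) -> None:
        self.data = data

    @classmethod
    def from_ray_dataset(cls, ray_ds: Any) -> "ToyIngestor":
        # A real implementation would adapt the Ray dataset into Arrow batches.
        return cls(ray_ds)


ingestor = ToyIngestor.from_ray_dataset(ray_ds="stand-in for ray.data.Dataset")
print(type(ingestor.data))  # <class 'str'>
```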
snowflake/ml/data/torch_utils.py
@@ -95,6 +95,6 @@ def _preprocess_array(
     array_list = arr.tolist()
     # If this is an array of arrays, convert the dtype to match the underlying array.
     # Otherwise, if this is a numpy array of strings, convert the array to a list.
-    arr = np.array(array_list, dtype=arr.flat[0].dtype) if isinstance(arr.flat[0], np.ndarray) else array_list
+    arr = np.array(array_list, dtype=arr.item(0).dtype) if isinstance(arr.item(0), np.ndarray) else array_list
 
     return arr

snowflake/ml/dataset/dataset.py
@@ -14,6 +14,7 @@ from snowflake.ml._internal.exceptions import (
 from snowflake.ml._internal.utils import (
     formatting,
     identifier,
+    mixins,
     query_result_checker,
     snowpark_dataframe_utils,
 )
@@ -27,7 +28,7 @@ _METADATA_MAX_QUERY_LENGTH = 10000
 _DATASET_VERSION_NAME_COL = "version"
 
 
-class DatasetVersion:
+class DatasetVersion(mixins.SerializableSessionMixin):
     """Represents a version of a Snowflake Dataset"""
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)

snowflake/ml/dataset/dataset_reader.py
@@ -3,6 +3,7 @@ from typing import Any, Optional
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
+from snowflake.ml._internal.utils import mixins
 from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -11,7 +12,7 @@ _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"
 
 
-class DatasetReader(data_connector.DataConnector):
+class DatasetReader(data_connector.DataConnector, mixins.SerializableSessionMixin):
     """Snowflake Dataset abstraction which provides application integration connectors"""
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
@@ -19,14 +20,21 @@ class DatasetReader(data_connector.DataConnector):
         self,
         ingestor: data_ingestor.DataIngestor,
         *,
-        snowpark_session: snowpark.Session,
+        snowpark_session: Optional[snowpark.Session] = None,
     ) -> None:
         super().__init__(ingestor)
 
-        self._session: snowpark.Session = snowpark_session
-        self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
+        self._session = snowpark_session
+        self._fs_cached: Optional[snowfs.SnowFileSystem] = None
         self._files: Optional[list[str]] = None
 
+    @property
+    def _fs(self) -> snowfs.SnowFileSystem:
+        if self._fs_cached is None:
+            assert self._session is not None
+            self._fs_cached = ingestor_utils.get_dataset_filesystem(self._session)
+        return self._fs_cached
+
     @classmethod
     def from_dataframe(
         cls, df: snowpark.DataFrame, ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None, **kwargs: Any
@@ -42,6 +50,7 @@ class DatasetReader(data_connector.DataConnector):
         files: list[str] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo)
+            assert self._session is not None
             files.extend(ingestor_utils.get_dataset_files(self._session, source, filesystem=self._fs))
         files.sort()
 
@@ -95,6 +104,7 @@ class DatasetReader(data_connector.DataConnector):
         dfs: list[snowpark.DataFrame] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
+            assert self._session is not None
             stage_reader = self._session.read.option("pattern", file_path_pattern)
             if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
                 stage_reader = stage_reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
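The `_fs` rework is a lazy-caching property: the filesystem handle is no longer built in `__init__`, so a `DatasetReader` whose session is restored later by `SerializableSessionMixin` only constructs it on first use. The pattern, reduced to a runnable sketch with local names:

```python
# Defer an expensive construction until first access, then reuse it.
from typing import Optional


class Reader:
    def __init__(self) -> None:
        self._fs_cached: Optional[str] = None  # stand-in for snowfs.SnowFileSystem

    @property
    def _fs(self) -> str:
        if self._fs_cached is None:
            self._fs_cached = "filesystem handle"  # built on demand from the session
        return self._fs_cached


r = Reader()
print(r._fs)           # constructed here, on first access
print(r._fs is r._fs)  # True: later accesses reuse the cached handle
```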
snowflake/ml/experiment/__init__.py
@@ -0,0 +1,3 @@
+from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+
+__all__ = ["ExperimentTracking"]