snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/utils/table_manager.py +9 -9

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import formatting, identifier, query_result_checker
@@ -24,8 +24,8 @@ def create_single_table(
     database_name: str,
     schema_name: str,
     table_name: str,
-    table_schema: List[Tuple[str, str]],
-    statement_params: Optional[Dict[str, Any]] = None,
+    table_schema: list[tuple[str, str]],
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> str:
     """Creates a single table for registry and returns the fully qualified name of the table.
 
@@ -55,7 +55,7 @@ def create_single_table(
     return fully_qualified_table_name
 
 
-def insert_table_entry(session: snowpark.Session, table: str, columns: Dict[str, Any]) -> List[snowpark.Row]:
+def insert_table_entry(session: snowpark.Session, table: str, columns: dict[str, Any]) -> list[snowpark.Row]:
     """Insert an entry into an internal Model Registry table.
 
     Args:
@@ -99,9 +99,9 @@ def validate_table_exist(session: snowpark.Session, table: str, qualified_schema
     return len(tables) == 1
 
 
-def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) -> Dict[str, str]:
+def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) -> dict[str, str]:
     result = session.sql(f"DESC TABLE {qualified_schema_name}.{table_name}").collect()
-    schema_dict: Dict[str, str] = {}
+    schema_dict: dict[str, str] = {}
     for row in result:
         schema_dict[row["name"]] = row["type"]
     return schema_dict
@@ -112,13 +112,13 @@ def get_table_schema_types(
     database: str,
     schema: str,
     table_name: str,
-) -> Dict[str, types.DataType]:
+) -> dict[str, types.DataType]:
     fully_qualified_table_name = identifier.get_schema_level_object_identifier(
         db=database, schema=schema, object_name=table_name
     )
-    struct_fields: List[types.StructField] = session.table(fully_qualified_table_name).schema.fields
+    struct_fields: list[types.StructField] = session.table(fully_qualified_table_name).schema.fields
 
-    schema_dict: Dict[str, types.DataType] = {}
+    schema_dict: dict[str, types.DataType] = {}
     for field in struct_fields:
         schema_dict[field.name] = field.datatype
     return schema_dict
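
Note: most of the small per-file changes in this release follow the pattern visible above, replacing the deprecated typing aliases (Dict, List, Tuple, Type) with the PEP 585 built-in generics (dict, list, tuple, type), which are valid in annotations on Python 3.9 and later. A minimal sketch of the annotation style the package moves to; the helper below is illustrative, not part of the package:

from typing import Any

# 1.8.1 style: Dict[str, str], List[Tuple[str, str]]
# 1.8.3 style (PEP 585): dict[str, str], list[tuple[str, str]]
def describe_columns(rows: list[dict[str, Any]]) -> dict[str, str]:
    """Collapse DESC TABLE-style rows into a column-name -> type mapping."""
    return {row["name"]: row["type"] for row in rows}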

snowflake/ml/data/_internal/arrow_ingestor.py +7 -7

@@ -2,7 +2,7 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Dict, Iterator, List, Optional, Sequence, Union
+from typing import Any, Deque, Iterator, Optional, Sequence, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -71,7 +71,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         return cls(session, sources)
 
     @property
-    def data_sources(self) -> List[data_source.DataSource]:
+    def data_sources(self) -> list[data_source.DataSource]:
         return self._data_sources
 
     def to_batches(
@@ -79,7 +79,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         batch_size: int,
         shuffle: bool = True,
         drop_last_batch: bool = True,
-    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+    ) -> Iterator[dict[str, npt.NDArray[Any]]]:
         """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
 
         As we are generating batches with the exactly same length, the last few rows in each file might get left as they
@@ -120,7 +120,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
 
     def _get_dataset(self, shuffle: bool) -> pds.Dataset:
         format = self._format
-        sources: List[Any] = []
+        sources: list[Any] = []
         source_format = None
         for source in self._data_sources:
             if isinstance(source, str):
@@ -155,7 +155,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
         return pa_dataset
 
-    def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+    def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
         """Generate new batches from the existing record batch buffer."""
         cnt_rbs_num_rows = 0
         candidates = []
@@ -180,7 +180,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         return _record_batch_to_arrays(res)
 
 
-def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+def _merge_record_batches(record_batches: list[pa.RecordBatch]) -> pa.RecordBatch:
     """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
     if not record_batches:
         return _EMPTY_RECORD_BATCH
@@ -192,7 +192,7 @@ def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatc
     return batches[0]
 
 
-def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+def _record_batch_to_arrays(rb: pa.RecordBatch) -> dict[str, npt.NDArray[Any]]:
     """Transform the record batch to a (string, numpy array) dict."""
     batch_dict = {}
     for column, column_schema in zip(rb, rb.schema):

snowflake/ml/data/data_connector.py +40 -36

@@ -1,32 +1,18 @@
 import os
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Generator,
-    List,
-    Optional,
-    Sequence,
-    Type,
-    TypeVar,
-    cast,
-)
+from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence, TypeVar
 
 import numpy.typing as npt
 from typing_extensions import deprecated
 
 from snowflake import snowpark
-from snowflake.ml._internal import telemetry
+from snowflake.ml._internal import env, telemetry
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
-from snowflake.ml.modeling._internal.constants import (
-    IN_ML_RUNTIME_ENV_VAR,
-    USE_OPTIMIZED_DATA_INGESTOR,
-)
 from snowflake.snowpark import context as sf_context
 
 if TYPE_CHECKING:
     import pandas as pd
+    import ray
     import tensorflow as tf
     from torch.utils import data as torch_data
 
@@ -42,7 +28,7 @@ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
 class DataConnector:
     """Snowflake data reader which provides application integration connectors"""
 
-    DEFAULT_INGESTOR_CLASS: Type[data_ingestor.DataIngestor] = ArrowIngestor
+    DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor
 
     def __init__(
         self,
@@ -53,27 +39,22 @@ class DataConnector:
         self._kwargs = kwargs
 
     @classmethod
-    @snowpark._internal.utils.private_preview(version="1.6.0")
     def from_dataframe(
-        cls: Type[DataConnectorType],
+        cls: type[DataConnectorType],
         df: snowpark.DataFrame,
-        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
-        return cast(
-            DataConnectorType,
-            cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
-        )
+        return cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs)
 
     @classmethod
-    @snowpark._internal.utils.private_preview(version="1.7.3")
     def from_sql(
-        cls: Type[DataConnectorType],
+        cls: type[DataConnectorType],
         query: str,
         session: Optional[snowpark.Session] = None,
-        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         session = session or sf_context.get_active_session()
@@ -82,9 +63,9 @@ class DataConnector:
 
     @classmethod
     def from_dataset(
-        cls: Type[DataConnectorType],
+        cls: type[DataConnectorType],
         ds: "dataset.Dataset",
-        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         dsv = ds.selected_version
@@ -101,10 +82,10 @@ class DataConnector:
         func_params_to_log=["sources", "ingestor_class"],
     )
     def from_sources(
-        cls: Type[DataConnectorType],
+        cls: type[DataConnectorType],
         session: snowpark.Session,
         sources: Sequence[data_source.DataSource],
-        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
@@ -112,7 +93,7 @@ class DataConnector:
         return cls(ingestor, **kwargs)
 
     @property
-    def data_sources(self) -> List[data_source.DataSource]:
+    def data_sources(self) -> list[data_source.DataSource]:
         return self._ingestor.data_sources
 
     @telemetry.send_api_usage_telemetry(
@@ -138,7 +119,7 @@ class DataConnector:
         """
         import tensorflow as tf
 
-        def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
+        def generator() -> Generator[dict[str, npt.NDArray[Any]], None, None]:
             yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
 
         # Derive TensorFlow signature
@@ -241,14 +222,37 @@ class DataConnector:
         """
         return self._ingestor.to_pandas(limit)
 
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["limit"],
+    )
+    def to_ray_dataset(self) -> "ray.data.Dataset":
+        """Retrieve the Snowflake data as a Ray Dataset.
+
+        Returns:
+            A Ray Dataset.
+
+        Raises:
+            ImportError: If Ray is not installed in the local environment.
+        """
+        if hasattr(self._ingestor, "to_ray_dataset"):
+            return self._ingestor.to_ray_dataset()
+
+        try:
+            import ray
+
+            return ray.data.from_pandas(self._ingestor.to_pandas())
+        except ImportError as e:
+            raise ImportError("Ray is not installed, please install ray in your local environment.") from e
+
 
 # Switch to use Runtime's Data Ingester if running in ML runtime
 # Fail silently if the data ingester is not found
-if os.getenv(IN_ML_RUNTIME_ENV_VAR) and os.getenv(USE_OPTIMIZED_DATA_INGESTOR):
+if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
     try:
         from runtime_external_entities import get_ingester_class
 
        DataConnector.DEFAULT_INGESTOR_CLASS = get_ingester_class()
     except ImportError:
         """Runtime Default Ingester not found, ignore"""
-        pass
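
Note: to_ray_dataset() is new in this version; it delegates to the ingestor when a native implementation exists and otherwise falls back to ray.data.from_pandas on the pandas conversion. A rough usage sketch, assuming an already-configured Snowflake connection, a placeholder table name, and Ray installed locally:

from snowflake.ml.data.data_connector import DataConnector
from snowflake.snowpark import Session

session = Session.builder.getOrCreate()  # assumes a default Snowflake connection is configured
df = session.table("MY_DB.MY_SCHEMA.MY_TRAINING_TABLE")  # placeholder table name

connector = DataConnector.from_dataframe(df)
ray_ds = connector.to_ray_dataset()  # ray.data.Dataset; raises ImportError if Ray is missing
print(ray_ds.take(5))  # peek at a few rows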

snowflake/ml/data/data_ingestor.py +4 -15

@@ -1,15 +1,4 @@
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Protocol,
-    Sequence,
-    Type,
-    TypeVar,
-)
+from typing import TYPE_CHECKING, Any, Iterator, Optional, Protocol, Sequence, TypeVar
 
 from numpy import typing as npt
 
@@ -26,12 +15,12 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls: Type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
+        cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
     ) -> DataIngestorType:
         raise NotImplementedError
 
     @property
-    def data_sources(self) -> List[data_source.DataSource]:
+    def data_sources(self) -> list[data_source.DataSource]:
         raise NotImplementedError
 
     def to_batches(
@@ -39,7 +28,7 @@ class DataIngestor(Protocol):
         batch_size: int,
         shuffle: bool = True,
         drop_last_batch: bool = True,
-    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+    ) -> Iterator[dict[str, npt.NDArray[Any]]]:
         raise NotImplementedError
 
     def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
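
Note: DataIngestor is a typing Protocol, so alternative ingestors (such as the optimized runtime ingestor swapped in at the bottom of data_connector.py above) only need to expose the same surface. A minimal in-memory sketch of that surface follows; the class is illustrative, not part of the package, and a real from_sources would materialize the sources through the session:

from typing import Any, Iterator, Optional, Sequence

import numpy.typing as npt
import pandas as pd

from snowflake import snowpark
from snowflake.ml.data import data_source


class PandasIngestor:
    """Illustrative DataIngestor that serves batches from an in-memory DataFrame."""

    def __init__(self, df: pd.DataFrame, sources: Sequence[data_source.DataSource]) -> None:
        self._df = df
        self._sources = list(sources)

    @classmethod
    def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "PandasIngestor":
        # A real implementation would load the source queries/files via the session.
        raise NotImplementedError

    @property
    def data_sources(self) -> list[data_source.DataSource]:
        return self._sources

    def to_batches(
        self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
    ) -> Iterator[dict[str, npt.NDArray[Any]]]:
        df = self._df.sample(frac=1) if shuffle else self._df
        stop = len(df) - (len(df) % batch_size) if drop_last_batch else len(df)
        for start in range(0, stop, batch_size):
            chunk = df.iloc[start : start + batch_size]
            yield {col: chunk[col].to_numpy() for col in chunk.columns}

    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
        return self._df.head(limit) if limit is not None else self._df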

snowflake/ml/data/data_source.py +2 -2

@@ -1,5 +1,5 @@
 import dataclasses
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 
 @dataclasses.dataclass(frozen=True)
@@ -17,7 +17,7 @@ class DatasetInfo:
     fully_qualified_name: str
     version: str
     url: Optional[str] = None
-    exclude_cols: Optional[List[str]] = None
+    exclude_cols: Optional[list[str]] = None
 
 
 DataSource = Union[DataFrameInfo, DatasetInfo, str]

snowflake/ml/data/ingestor_utils.py +3 -3

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 import fsspec
 import pyarrow as pa
@@ -33,7 +33,7 @@ def _get_dataframe_cursor(session: snowpark.Session, df_info: data_source.DataFr
 
 def get_dataframe_result_batches(
     session: snowpark.Session, df_info: data_source.DataFrameInfo
-) -> List[result_batch.ResultBatch]:
+) -> list[result_batch.ResultBatch]:
     """Retrieve the ResultBatches for a given query"""
     cursor = _get_dataframe_cursor(session, df_info)
     batches = cursor.get_result_batches()
@@ -63,7 +63,7 @@ def get_dataset_filesystem(
 
 def get_dataset_files(
     session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
-) -> List[str]:
+) -> list[str]:
     """Get the list of files in a given Dataset"""
     if filesystem is None:
         filesystem = get_dataset_filesystem(session, ds_info)

snowflake/ml/data/torch_utils.py +5 -5

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Iterator, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -7,7 +7,7 @@ import torch.utils.data
 from snowflake.ml.data import data_ingestor
 
 
-class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
+class TorchDatasetWrapper(torch.utils.data.IterableDataset[dict[str, Any]]):
     """Wrap a DataIngestor into a PyTorch IterableDataset"""
 
     def __init__(
@@ -32,7 +32,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         self._squeeze_outputs = squeeze
         self._expand_dims = expand_dims
 
-    def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
+    def __iter__(self) -> Iterator[dict[str, Union[npt.NDArray[Any], list[Any]]]]:
         max_idx = 0
         filter_idx = 0
         worker_info = torch.utils.data.get_worker_info()
@@ -59,7 +59,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
             counter = 0
 
 
-class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Dict[str, Any]]):
+class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[dict[str, Any]]):
     """Wrap a DataIngestor into a PyTorch IterDataPipe"""
 
     def __init__(
@@ -77,7 +77,7 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Di
 
 def _preprocess_array(
     arr: npt.NDArray[Any], squeeze: bool = False, expand_dims: bool = True
-) -> Union[npt.NDArray[Any], List[np.object_]]:
+) -> Union[npt.NDArray[Any], list[np.object_]]:
     """Preprocesses batch column values."""
     single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_
 

snowflake/ml/dataset/dataset.py +11 -11

@@ -1,7 +1,7 @@
 import json
 import warnings
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
@@ -46,8 +46,8 @@ class DatasetVersion:
         self._version = version
         self._session: snowpark.Session = self._parent._session
 
-        self._properties: Optional[Dict[str, Any]] = None
-        self._raw_metadata: Optional[Dict[str, Any]] = None
+        self._properties: Optional[dict[str, Any]] = None
+        self._raw_metadata: Optional[dict[str, Any]] = None
         self._metadata: Optional[dataset_metadata.DatasetMetadata] = None
 
     @property
@@ -66,14 +66,14 @@ class DatasetVersion:
         return comment
 
     @property
-    def label_cols(self) -> List[str]:
+    def label_cols(self) -> list[str]:
         metadata = self._get_metadata()
         if metadata is None or metadata.label_cols is None:
             return []
         return metadata.label_cols
 
     @property
-    def exclude_cols(self) -> List[str]:
+    def exclude_cols(self) -> list[str]:
         metadata = self._get_metadata()
         if metadata is None or metadata.exclude_cols is None:
             return []
@@ -115,7 +115,7 @@ class DatasetVersion:
         return path
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
-    def list_files(self, subdir: Optional[str] = None) -> List[snowpark.Row]:
+    def list_files(self, subdir: Optional[str] = None) -> list[snowpark.Row]:
         """Get the list of remote file paths for the current DatasetVersion."""
         return self._session.sql(f"LIST {self.url()}{subdir or ''}").collect(
             statement_params=_TELEMETRY_STATEMENT_PARAMS
@@ -244,7 +244,7 @@ class Dataset(lineage_node.LineageNode):
             raise
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT)
-    def list_versions(self, detailed: bool = False) -> Union[List[str], List[snowpark.Row]]:
+    def list_versions(self, detailed: bool = False) -> Union[list[str], list[snowpark.Row]]:
         """Return list of versions"""
         versions = self._list_versions()
         versions.sort(key=lambda r: r[_DATASET_VERSION_NAME_COL])
@@ -271,8 +271,8 @@ class Dataset(lineage_node.LineageNode):
         version: str,
         input_dataframe: snowpark.DataFrame,
         shuffle: bool = False,
-        exclude_cols: Optional[List[str]] = None,
-        label_cols: Optional[List[str]] = None,
+        exclude_cols: Optional[list[str]] = None,
+        label_cols: Optional[list[str]] = None,
         properties: Optional[dataset_metadata.DatasetPropertiesType] = None,
         partition_by: Optional[str] = None,
         comment: Optional[str] = None,
@@ -423,7 +423,7 @@ class Dataset(lineage_node.LineageNode):
             statement_params=_TELEMETRY_STATEMENT_PARAMS
         )
 
-    def _list_versions(self, pattern: Optional[str] = None) -> List[snowpark.Row]:
+    def _list_versions(self, pattern: Optional[str] = None) -> list[snowpark.Row]:
         """Return list of versions"""
         try:
             pattern_clause = f" LIKE '{pattern}'" if pattern else ""
@@ -469,7 +469,7 @@ lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
 # Utility methods
 
 
-def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> Tuple[str, str, str]:
+def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> tuple[str, str, str]:
     """Resolve a dataset name into a validated schema-level location identifier"""
     db, schema, object_name = identifier.parse_schema_level_object_identifier(dataset_name)
     db = db or session.get_current_database()

snowflake/ml/dataset/dataset_metadata.py +8 -8

@@ -1,7 +1,7 @@
 import dataclasses
 import json
 import typing
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 _PROPERTY_TYPE_KEY = "$proptype$"
 DATASET_SCHEMA_VERSION = "1"
@@ -20,15 +20,15 @@ class FeatureStoreMetadata:
     """
 
     spine_query: str
-    serialized_feature_views: Optional[List[str]] = None
-    compact_feature_views: Optional[List[str]] = None
+    serialized_feature_views: Optional[list[str]] = None
+    compact_feature_views: Optional[list[str]] = None
     spine_timestamp_col: Optional[str] = None
 
    def to_json(self) -> str:
        return json.dumps(dataclasses.asdict(self))

    @classmethod
-    def from_json(cls, input_json: Union[Dict[str, Any], str, bytes]) -> "FeatureStoreMetadata":
+    def from_json(cls, input_json: Union[dict[str, Any], str, bytes]) -> "FeatureStoreMetadata":
        if isinstance(input_json, dict):
            return cls(**input_json)
        return cls(**json.loads(input_json))
@@ -61,8 +61,8 @@ class DatasetMetadata:
 
     source_query: str
     owner: str
-    exclude_cols: Optional[List[str]] = None
-    label_cols: Optional[List[str]] = None
+    exclude_cols: Optional[list[str]] = None
+    label_cols: Optional[list[str]] = None
     properties: Optional[DatasetPropertiesType] = None
     schema_version: str = dataclasses.field(default=DATASET_SCHEMA_VERSION, init=False)
 
@@ -78,11 +78,11 @@ class DatasetMetadata:
         return json.dumps(state_dict)
 
     @classmethod
-    def from_json(cls, input_json: Union[Dict[str, Any], str, bytes]) -> "DatasetMetadata":
+    def from_json(cls, input_json: Union[dict[str, Any], str, bytes]) -> "DatasetMetadata":
         if not input_json:
             raise ValueError("json_str was empty or None")
         try:
-            state_dict: Dict[str, Any] = (
+            state_dict: dict[str, Any] = (
                 input_json if isinstance(input_json, dict) else json.loads(input_json, strict=False)
             )
 

snowflake/ml/dataset/dataset_reader.py +12 -8

@@ -1,10 +1,11 @@
-from typing import Any, List, Optional, Type
+from typing import Any, Optional
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
+from snowflake.snowpark._internal import utils as snowpark_utils
 
 _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"
@@ -24,21 +25,21 @@ class DatasetReader(data_connector.DataConnector):
 
         self._session: snowpark.Session = snowpark_session
         self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
-        self._files: Optional[List[str]] = None
+        self._files: Optional[list[str]] = None
 
     @classmethod
     def from_dataframe(
-        cls, df: snowpark.DataFrame, ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None, **kwargs: Any
+        cls, df: snowpark.DataFrame, ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None, **kwargs: Any
     ) -> "DatasetReader":
         # Block superclass constructor from Snowpark DataFrames
         raise RuntimeError("Creating DatasetReader from DataFrames not supported")
 
-    def _list_files(self) -> List[str]:
+    def _list_files(self) -> list[str]:
         """Private helper function that lists all files in this DatasetVersion and caches the results."""
         if self._files:
             return self._files
 
-        files: List[str] = []
+        files: list[str] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo)
             files.extend(ingestor_utils.get_dataset_files(self._session, source, filesystem=self._fs))
@@ -48,7 +49,7 @@ class DatasetReader(data_connector.DataConnector):
         return self._files
 
     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
-    def files(self) -> List[str]:
+    def files(self) -> list[str]:
         """Get the list of remote file paths for the current DatasetVersion.
 
         The file paths follows the snow protocol.
@@ -91,10 +92,13 @@ class DatasetReader(data_connector.DataConnector):
         For example, an OBJECT column may be scanned back as a STRING column.
         """
         file_path_pattern = ".*data_.*[.]parquet"
-        dfs: List[snowpark.DataFrame] = []
+        dfs: list[snowpark.DataFrame] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
-            df = self._session.read.option("pattern", file_path_pattern).parquet(source.url)
+            stage_reader = self._session.read.option("pattern", file_path_pattern)
+            if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
+                stage_reader = stage_reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
+            df = stage_reader.parquet(source.url)
             if only_feature_cols and source.exclude_cols:
                 df = df.drop(source.exclude_cols)
             dfs.append(df)
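
Note: the INFER_SCHEMA_OPTIONS reader option is only passed when the installed snowpark-python build recognizes it (the NON_FORMAT_TYPE_OPTIONS check above), so the reader keeps working against older Snowpark releases. The same guard, extracted as a standalone sketch with an illustrative helper name:

from snowflake.snowpark import Session
from snowflake.snowpark._internal import utils as snowpark_utils


def parquet_reader(session: Session, pattern: str):
    """Build a stage reader; limit schema inference to one file when the option is supported."""
    reader = session.read.option("pattern", pattern)
    if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
        reader = reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
    return reader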

snowflake/ml/feature_store/__init__.py +1 -1

@@ -4,7 +4,7 @@ from snowflake.ml._internal import init_utils
 
 from .access_manager import setup_feature_store
 
-pkg_dir = os.path.dirname(os.path.abspath(__file__))
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_classes.items():

snowflake/ml/feature_store/access_manager.py +7 -7

@@ -1,6 +1,6 @@
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Optional
 from warnings import warn
 
 from snowflake.ml._internal import telemetry
@@ -28,7 +28,7 @@ class _FeatureStoreRole(Enum):
 class _Privilege:
     object_type: str
     object_name: str
-    privileges: List[str]
+    privileges: list[str]
     scope: Optional[str] = None
     optional: bool = False
 
@@ -41,7 +41,7 @@ class _SessionInfo:
 
 
 # Lists of permissions as tuples of (OBJECT_TYPE, [PRIVILEGES, ...])
-_PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
+_PRE_INIT_PRIVILEGES: dict[_FeatureStoreRole, list[_Privilege]] = {
     _FeatureStoreRole.PRODUCER: [
         _Privilege("DATABASE", "{database}", ["USAGE"]),
         _Privilege("SCHEMA", "{database}.{schema}", ["USAGE"]),
@@ -78,7 +78,7 @@ _PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
     _FeatureStoreRole.NONE: [],
 }
 
-_POST_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
+_POST_INIT_PRIVILEGES: dict[_FeatureStoreRole, list[_Privilege]] = {
     _FeatureStoreRole.PRODUCER: [
         _Privilege("TAG", f"{{database}}.{{schema}}.{_FEATURE_VIEW_METADATA_TAG}", ["APPLY"]),
         _Privilege("TAG", f"{{database}}.{{schema}}.{_FEATURE_STORE_OBJECT_TAG}", ["APPLY"]),
@@ -89,7 +89,7 @@ _POST_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
 
 
 def _grant_privileges(
-    session: Session, role_name: str, privileges: List[_Privilege], session_info: _SessionInfo
+    session: Session, role_name: str, privileges: list[_Privilege], session_info: _SessionInfo
 ) -> None:
     session_info_dict = asdict(session_info)
     for p in privileges:
@@ -129,7 +129,7 @@ def _grant_privileges(
 def _configure_pre_init_privileges(
     session: Session,
     session_info: _SessionInfo,
-    roles_to_create: Dict[_FeatureStoreRole, str],
+    roles_to_create: dict[_FeatureStoreRole, str],
 ) -> None:
     """
     Configure Feature Store role privileges. Must be run with ACCOUNTADMIN
@@ -172,7 +172,7 @@ def _configure_pre_init_privileges(
 def _configure_post_init_privileges(
     session: Session,
     session_info: _SessionInfo,
-    roles_to_create: Dict[_FeatureStoreRole, str],
+    roles_to_create: dict[_FeatureStoreRole, str],
 ) -> None:
     for role_type, role in roles_to_create.items():
         _grant_privileges(session, role, _POST_INIT_PRIVILEGES[role_type], session_info)