snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (69)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +82 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/utils/identifier.py +4 -2
  12. snowflake/ml/data/__init__.py +3 -0
  13. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  14. snowflake/ml/data/data_connector.py +53 -11
  15. snowflake/ml/data/data_ingestor.py +2 -1
  16. snowflake/ml/data/torch_utils.py +18 -5
  17. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  18. snowflake/ml/fileset/fileset.py +18 -18
  19. snowflake/ml/model/_client/model/model_version_impl.py +5 -3
  20. snowflake/ml/model/_client/ops/model_ops.py +2 -6
  21. snowflake/ml/model/_client/sql/model_version.py +11 -0
  22. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  25. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  26. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  27. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  28. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  29. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  30. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  31. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  32. snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
  33. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  34. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
  35. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  36. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  37. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
  38. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  39. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  40. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  42. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  43. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  44. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  45. snowflake/ml/model/_signatures/pandas_handler.py +1 -1
  46. snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
  47. snowflake/ml/model/type_hints.py +1 -0
  48. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  49. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  50. snowflake/ml/modeling/pipeline/pipeline.py +6 -176
  51. snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
  52. snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
  53. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
  54. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
  55. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
  56. snowflake/ml/registry/_manager/model_manager.py +70 -33
  57. snowflake/ml/registry/registry.py +41 -22
  58. snowflake/ml/version.py +1 -1
  59. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +38 -9
  60. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +63 -67
  61. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
  62. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  63. snowflake/ml/fileset/parquet_parser.py +0 -170
  64. snowflake/ml/fileset/tf_dataset.py +0 -88
  65. snowflake/ml/fileset/torch_datapipe.py +0 -57
  66. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  67. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  68. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
  69. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
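
Note on the removed data-loading helpers: entries 63-65 above delete snowflake/ml/fileset/parquet_parser.py, tf_dataset.py, and torch_datapipe.py (their full contents appear as deletions below), while snowflake/ml/data/data_connector.py and snowflake/ml/data/torch_utils.py gain code in the same release. The sketch below is not taken from this diff; it only illustrates, under the assumption that snowflake.ml.data.DataConnector exposes from_dataframe(), to_tf_dataset(), and to_torch_datapipe() with the keyword arguments shown, how the DataConnector surface covers the same TensorFlow/PyTorch loading use case. The table name and connection parameters are placeholders.

# Sketch only: assumed DataConnector API, verify against the 1.7.3 reference before use.
from snowflake.ml.data import DataConnector
from snowflake.snowpark import Session

connection_parameters = {"account": "...", "user": "...", "password": "..."}  # fill in your own
session = Session.builder.configs(connection_parameters).create()

# Hypothetical training table standing in for the parquet files the removed helpers read.
df = session.table("MY_DB.MY_SCHEMA.TRAINING_DATA")
connector = DataConnector.from_dataframe(df)

# TensorFlow path, analogous to the removed fileset.tf_dataset.read_and_parse_parquet().
tf_ds = connector.to_tf_dataset(batch_size=2, shuffle=True, drop_last_batch=True)

# PyTorch path, analogous to the removed fileset.torch_datapipe.ReadAndParseParquet.
torch_dp = connector.to_torch_datapipe(batch_size=2, shuffle=True, drop_last_batch=True)
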
snowflake/ml/fileset/parquet_parser.py (deleted)
@@ -1,170 +0,0 @@
- import collections
- import logging
- import time
- from typing import Any, Deque, Dict, Iterator, List
-
- import fsspec
- import numpy as np
- import numpy.typing as npt
- import pyarrow as pa
- import pyarrow.dataset as ds
-
- _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
-
- # The row count for batches read from PyArrow Dataset. This number should be large enough so that
- # dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
- _DEFAULT_DATASET_BATCH_SIZE = 1000000
-
-
- class _RecordBatchesBuffer:
-     """A queue that stores record batches and tracks the total num of rows in it."""
-
-     def __init__(self) -> None:
-         self.buffer: Deque[pa.RecordBatch] = collections.deque()
-         self.num_rows = 0
-
-     def append(self, rb: pa.RecordBatch) -> None:
-         self.buffer.append(rb)
-         self.num_rows += rb.num_rows
-
-     def appendleft(self, rb: pa.RecordBatch) -> None:
-         self.buffer.appendleft(rb)
-         self.num_rows += rb.num_rows
-
-     def popleft(self) -> pa.RecordBatch:
-         popped = self.buffer.popleft()
-         self.num_rows -= popped.num_rows
-         return popped
-
-
- class ParquetParser:
-     """Read and parse the given parquet files and yield batched numpy array in dict.
-
-     Args:
-         file_paths: A list of parquet file URIs to read and parse.
-         filesystem: A fsspec/pyarrow file system that is used to open given file URIs.
-         batch_size: Specifies the size of each batch that will be yield
-         shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
-             the order of files, and then shuflle the order of rows in each file.
-         drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last batch will
-             get dropped if its size is smaller than the given batch_size.
-
-     Returns:
-         A PyTorch iterable datapipe that yields batched numpy array in dict. The keys will be the column names in
-         the parquet files, and the value will be the column value as a list.
-     """
-
-     def __init__(
-         self,
-         file_paths: List[str],
-         filesystem: fsspec.AbstractFileSystem,
-         batch_size: int,
-         shuffle: bool = True,
-         drop_last_batch: bool = True,
-     ) -> None:
-         self._file_paths = file_paths
-         self._fs = filesystem
-         self._batch_size = batch_size
-         self._dataset_batch_size = max(_DEFAULT_DATASET_BATCH_SIZE, self._batch_size)
-         self._shuffle = shuffle
-         self._drop_last_batch = drop_last_batch
-
-     def __iter__(self) -> Iterator[Dict[str, npt.NDArray[Any]]]:
-         """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
-
-         As we are generating batches with the exactly same length, the last few rows in each file might get left as they
-         are not long enough to form a batch. These rows will be put into a temporary buffer and combine with the first
-         few rows of the next file to generate a new batch.
-
-         Yields:
-             A dict mapping column names to the corresponding data fetch from that column.
-         """
-         self._rb_buffer = _RecordBatchesBuffer()
-         files = list(self._file_paths)
-         if self._shuffle:
-             np.random.shuffle(files)
-         pa_dataset: ds.Dataset = ds.dataset(files, format="parquet", filesystem=self._fs)
-
-         for rb in _retryable_batches(pa_dataset, batch_size=self._dataset_batch_size):
-             if self._shuffle:
-                 rb = rb.take(np.random.permutation(rb.num_rows))
-             self._rb_buffer.append(rb)
-             while self._rb_buffer.num_rows >= self._batch_size:
-                 yield self._get_batches_from_buffer()
-
-         if self._rb_buffer.num_rows and not self._drop_last_batch:
-             yield self._get_batches_from_buffer()
-
-     def _get_batches_from_buffer(self) -> Dict[str, npt.NDArray[Any]]:
-         """Generate new batches from the existing record batch buffer."""
-         cnt_rbs_num_rows = 0
-         candidates = []
-
-         # Keep popping record batches in buffer until there are enough rows for a batch.
-         while self._rb_buffer.num_rows and cnt_rbs_num_rows < self._batch_size:
-             candidate = self._rb_buffer.popleft()
-             cnt_rbs_num_rows += candidate.num_rows
-             candidates.append(candidate)
-
-         # When there are more rows than needed, slice the last popped batch to fit batch_size.
-         if cnt_rbs_num_rows > self._batch_size:
-             row_diff = cnt_rbs_num_rows - self._batch_size
-             slice_target = candidates[-1]
-             cut_off = slice_target.num_rows - row_diff
-             to_merge = slice_target.slice(length=cut_off)
-             left_over = slice_target.slice(offset=cut_off)
-             candidates[-1] = to_merge
-             self._rb_buffer.appendleft(left_over)
-
-         res = _merge_record_batches(candidates)
-         return _record_batch_to_arrays(res)
-
-
- def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
-     """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
-     if not record_batches:
-         return _EMPTY_RECORD_BATCH
-     if len(record_batches) == 1:
-         return record_batches[0]
-     record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
-     one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
-     batches = one_chunk_table.to_batches(max_chunksize=None)
-     return batches[0]
-
-
- def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
-     """Transform the record batch to a (string, numpy array) dict."""
-     batch_dict = {}
-     for column, column_schema in zip(rb, rb.schema):
-         # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
-         array = column.to_numpy(zero_copy_only=False)
-         batch_dict[column_schema.name] = array
-     return batch_dict
-
-
- def _retryable_batches(
-     dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
- ) -> Iterator[pa.RecordBatch]:
-     """Make the Dataset to_batches retryable."""
-     retries = 0
-     current_batch_index = 0
-
-     while True:
-         try:
-             for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
-                 if batch_index < current_batch_index:
-                     # Skip batches that have already been processed
-                     continue
-
-                 yield batch
-                 current_batch_index = batch_index + 1
-             # Exit the loop once all batches are processed
-             break
-
-         except Exception as e:
-             if retries < max_retries:
-                 retries += 1
-                 logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
-                 time.sleep(delay)
-             else:
-                 raise e
snowflake/ml/fileset/tf_dataset.py (deleted)
@@ -1,88 +0,0 @@
- from typing import Any, Dict, Generator, List
-
- import fsspec
- import numpy.typing as npt
- import pyarrow as pa
- import pyarrow.parquet as pq
- import tensorflow as tf
-
- from snowflake.ml._internal.exceptions import (
-     error_codes,
-     exceptions as snowml_exceptions,
- )
- from snowflake.ml.fileset import parquet_parser
-
-
- def read_and_parse_parquet(
-     files: List[str],
-     filesystem: fsspec.AbstractFileSystem,
-     batch_size: int,
-     shuffle: bool,
-     drop_last_batch: bool,
- ) -> tf.data.Dataset:
-     """Creates a tf.data.Dataset that reads given parquet files into batched Tensors.
-
-     Args:
-         files: A list of input parquet file URIs to read and parse. The parquet files should
-             have the same schema.
-         filesystem: A fsspec/pyarrow file system that is used to open given file URIs.
-         batch_size: Specifies the size of each batch that will be yield. It is preferred to
-             set it to your training batch size, and avoid using dataset.{batch(),rebatch()} later.
-         shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
-             the order of files, and then shuflle the order of rows in each file. It is preferred
-             to shuffle the data this way than dataset.unbatch().shuffle().rebatch().
-         drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last batch will
-             get dropped if its size is smaller than the given batch_size.
-
-     Returns:
-         A tf.data.Dataset generates batched Tensors in a dict. The keys will be the column names in
-         the parquet files.
-
-     Raises:
-         SnowflakeMLException: if `files` is empty.
-
-     Example:
-     >>> from snowflake.ml.fileset import sfcfs, tf_dataset
-     >>> conn = snowflake.connector.connect(**connection_parameters)
-     >>> fs = sfcfs.SFFileSystem(conn)
-     >>> files = fs.ls(dir_path)
-     >>> ds = tf_dataset.parse_and_read_parquet(files, fs, batch_size = 2)
-     >>> for batch in ds:
-     >>> print(batch)
-     ----
-     {'_COL_1': <tf.Tensor: shape=(2,), dtype=float32, numpy=[32.5000, 6.0000]>,
-     '_COL_2': <tf.Tensor: shape=(2,), dtype=float32, numpy=[-73.9542, -73.9875]>}
-     """
-     if not files:
-         raise snowml_exceptions.SnowflakeMLException(
-             error_code=error_codes.SNOWML_READ_FAILED,
-             original_exception=ValueError("At least one file is needed to create a TF dataset."),
-         )
-
-     def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
-         yield from parquet_parser.ParquetParser(list(files), filesystem, batch_size, shuffle, drop_last_batch)
-
-     return tf.data.Dataset.from_generator(generator, output_signature=_derive_signature(files[0], filesystem))
-
-
- def _arrow_type_to_tensor_spec(field: pa.Field) -> tf.TensorSpec:
-     try:
-         dtype = tf.dtypes.as_dtype(field.type.to_pandas_dtype())
-     except TypeError:
-         raise snowml_exceptions.SnowflakeMLException(
-             error_code=error_codes.INVALID_DATA_TYPE,
-             original_exception=TypeError(f"Column {field.name} has unsupportd type {field.type}."),
-         )
-     # First dimension is batch dimension.
-     return tf.TensorSpec(shape=(None,), dtype=dtype)
-
-
- def _derive_signature(file: str, filesystem: fsspec.AbstractFileSystem) -> Dict[str, tf.TensorSpec]:
-     """Derives the signature of the TF dataset from one parquet file."""
-     # TODO(zpeng): pq.read_schema does not support `filesystem` until pyarrow>=10.
-     # switch to pq.read_schema when we depend on that.
-     schema = pq.read_table(file, filesystem=filesystem).schema
-     # Signature:
-     # The dataset yields dicts. Keys are column names; values are 1-D tensors (
-     # the first dimension is batch dimension).
-     return {field.name: _arrow_type_to_tensor_spec(field) for field in schema}
snowflake/ml/fileset/torch_datapipe.py (deleted)
@@ -1,57 +0,0 @@
- from typing import Any, Dict, Iterator
-
- import fsspec
- import numpy.typing as npt
- from torchdata.datapipes.iter import IterDataPipe
-
- from snowflake.ml.fileset import parquet_parser
-
-
- class ReadAndParseParquet(IterDataPipe):
-     """Read and parse the parquet files yield batched numpy array in dict.
-
-     Args:
-         input_datapipe: A datapipe of input parquet file URIs to read and parse.
-             Note that the datapipe must be finite.
-         filesystem: A fsspec/pyarrow file system that is used to open given file URIs.
-         batch_size: Specifies the size of each batch that will be yield
-         shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
-             the order of files, and then shuflle the order of rows in each file.
-         drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last batch will
-             get dropped if its size is smaller than the given batch_size.
-
-     Returns:
-         A PyTorch iterable datapipe that yields batched numpy array in dict. The keys will be the column names in
-         the parquet files.
-
-     Example:
-     >>> from snowflake.ml.fileset import sfcfs, torch_datapipe
-     >>> from torchdata.datapipes.iter import FSSpecFileLister
-     >>> conn = snowflake.connector.connect(**connection_parameters)
-     >>> fs = sfcfs.SFFileSystem(conn)
-     >>> filedp = FSSpecFileLister(root=dir_path, masks="*.parquet", mode="rb", sf_connection=conn)
-     >>> parquet_dp = torch_datapipe.ReadAndParseParquet(file_dp, fs, batch_size = 2)
-     >>> for batch in parquet_dp:
-     >>> print(batch)
-     ----
-     {'_COL_1': [32.5000, 6.0000], '_COL_2': [-73.9542, -73.9875]}
-     """
-
-     def __init__(
-         self,
-         input_datapipe: IterDataPipe[str],
-         filesystem: fsspec.AbstractFileSystem,
-         batch_size: int,
-         shuffle: bool,
-         drop_last_batch: bool,
-     ) -> None:
-         self._input_datapipe = input_datapipe
-         self._fs = filesystem
-         self._batch_size = batch_size
-         self._shuffle = shuffle
-         self._drop_last_batch = drop_last_batch
-
-     def __iter__(self) -> Iterator[Dict[str, npt.NDArray[Any]]]:
-         yield from parquet_parser.ParquetParser(
-             list(self._input_datapipe), self._fs, self._batch_size, self._shuffle, self._drop_last_batch
-         )
snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py (deleted)
@@ -1,151 +0,0 @@
- from typing import Any, List, Optional
-
- from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import (
-     SnowparkTransformHandlers,
- )
- from snowflake.snowpark import DataFrame, Session
-
-
- class MLRuntimeTransformHandlers:
-     def __init__(
-         self,
-         dataset: DataFrame,
-         estimator: object,
-         class_name: str,
-         subproject: str,
-         autogenerated: Optional[bool] = False,
-     ) -> None:
-         """
-         Args:
-             dataset: The dataset to run transform functions on.
-             estimator: The estimator used to run transforms.
-             class_name: class name to be used in telemetry.
-             subproject: subproject to be used in telemetry.
-             autogenerated: Whether the class was autogenerated from a template.
-
-         Raises:
-             ModuleNotFoundError: The mlruntimes_client module is not available.
-         """
-         try:
-             from snowflake.ml.runtime import MLRuntimeClient
-         except ModuleNotFoundError as e:
-             # This is an internal exception, not a user-facing one. The snowflake.ml.runtime module should
-             # always be present when this class is instantiated.
-             raise e
-
-         self.client = MLRuntimeClient()
-         self.dataset = dataset
-         self.estimator = estimator
-         self._class_name = class_name
-         self._subproject = subproject
-         self._autogenerated = autogenerated
-
-     def batch_inference(
-         self,
-         inference_method: str,
-         input_cols: List[str],
-         expected_output_cols: List[str],
-         session: Session,
-         dependencies: List[str],
-         drop_input_cols: Optional[bool] = False,
-         expected_output_cols_type: Optional[str] = "",
-         *args: Any,
-         **kwargs: Any,
-     ) -> DataFrame:
-         """Run batch inference on the given dataset.
-         Temporary workaround - pushdown implementation is not currently ready for batch_inference.
-         We use a SnowparkTransformHandlers until we have a way to use the runtime client.
-
-         Args:
-             inference_method: the name of the method used by `estimator` to run inference.
-             input_cols: List of feature columns for inference.
-             session: An active Snowpark Session.
-             dependencies: List of dependencies for the transformer.
-             expected_output_cols: column names (in order) of the output dataset.
-             drop_input_cols: Boolean to determine whether to drop the input columns from the output dataset.
-             expected_output_cols_type: Expected type of the output columns.
-             args: additional positional arguments.
-             kwargs: additional keyword args.
-
-         Returns:
-             A new dataset of the same type as the input dataset.
-
-         """
-
-         mlrs_inference_methods = ["predict", "predict_proba", "predict_log_proba"]
-
-         if inference_method in mlrs_inference_methods:
-             result_df = self.client.inference(
-                 estimator=self.estimator,
-                 dataset=self.dataset,
-                 inference_method=inference_method,
-                 input_cols=input_cols,
-                 output_cols=expected_output_cols,
-                 drop_input_cols=drop_input_cols,
-             )
-
-         else:
-             handler = SnowparkTransformHandlers(
-                 dataset=self.dataset,
-                 estimator=self.estimator,
-                 class_name=self._class_name,
-                 subproject=self._subproject,
-                 autogenerated=self._autogenerated,
-             )
-             result_df = handler.batch_inference(
-                 inference_method,
-                 input_cols,
-                 expected_output_cols,
-                 session,
-                 dependencies,
-                 drop_input_cols,
-                 expected_output_cols_type,
-                 *args,
-                 **kwargs,
-             )
-
-         assert isinstance(result_df, DataFrame)  # mypy - The MLRS return types are annotated as `object`.
-         return result_df
-
-     def score(
-         self,
-         input_cols: List[str],
-         label_cols: List[str],
-         session: Session,
-         dependencies: List[str],
-         score_sproc_imports: List[str],
-         sample_weight_col: Optional[str] = None,
-         *args: Any,
-         **kwargs: Any,
-     ) -> float:
-         """Score the given test dataset.
-
-         Args:
-             session: An active Snowpark Session.
-             dependencies: score function dependencies.
-             score_sproc_imports: imports for score stored procedure.
-             input_cols: List of feature columns for inference.
-             label_cols: List of label columns for scoring.
-             sample_weight_col: A column assigning relative weights to each row for scoring.
-             args: additional positional arguments.
-             kwargs: additional keyword args.
-
-
-         Returns:
-             An accuracy score for the model on the given test data.
-
-         Raises:
-             TypeError: The ML Runtimes client returned a non-float result
-         """
-         output_score = self.client.score(
-             estimator=self.estimator,
-             dataset=self.dataset,
-             input_cols=input_cols,
-             label_cols=label_cols,
-             sample_weight_col=sample_weight_col,
-         )
-         if not isinstance(output_score, float):
-             raise TypeError(
-                 f"The ML Runtimes Client returned a non-float value {output_score} of type {type(output_score)}"
-             )
-         return output_score
snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py (deleted)
@@ -1,66 +0,0 @@
- from typing import List, Optional
-
- from snowflake.snowpark import DataFrame, Session
-
-
- class MLRuntimeModelTrainer:
-     """ML model training using the ml runties client."""
-
-     def __init__(
-         self,
-         estimator: object,
-         dataset: DataFrame,
-         session: Session,
-         input_cols: List[str],
-         label_cols: Optional[List[str]],
-         sample_weight_col: Optional[str],
-         autogenerated: bool = False,
-         subproject: str = "",
-     ) -> None:
-         """
-         Initializes the MLRuntimeModelTrainer with a model, a Snowpark DataFrame, feature, and label column names.
-
-         Args:
-             estimator: SKLearn compatible estimator or transformer object.
-             dataset: The dataset used for training the model.
-             session: Snowflake session object to be used for training.
-             input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
-             label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn.
-             sample_weight_col: The column name representing the weight of training examples.
-             autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not.
-             subproject: subproject name to be used in telemetry.
-
-         Raises:
-             ModuleNotFoundError: The mlruntimes_client module is not available.
-         """
-
-         try:
-             from snowflake.ml.runtime import MLRuntimeClient
-         except ModuleNotFoundError as e:
-             # This is an internal exception, not a user-facing one. The snowflake.ml.runtime module should
-             # always be present when this class is instantiated.
-             raise e
-
-         self.client = MLRuntimeClient()
-
-         self.estimator = estimator
-         self.dataset = dataset
-         self.session = session
-         self.input_cols = input_cols
-         self.label_cols = label_cols
-         self.sample_weight_col = sample_weight_col
-         self._autogenerated = autogenerated
-         self._subproject = subproject
-         self._class_name = estimator.__class__.__name__
-
-     def train(self) -> object:
-         """
-         Trains the model by pushing down the compute into SPCS ML Runtime
-         """
-         return self.client.train(
-             estimator=self.estimator,
-             dataset=self.dataset,
-             input_cols=self.input_cols,
-             label_cols=self.label_cols,
-             sample_weight_col=self.sample_weight_col,
-         )