snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +67 -10
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  6. snowflake/ml/_internal/telemetry.py +12 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  8. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  9. snowflake/ml/data/data_connector.py +133 -0
  10. snowflake/ml/data/data_ingestor.py +28 -0
  11. snowflake/ml/data/data_source.py +23 -0
  12. snowflake/ml/dataset/dataset.py +1 -13
  13. snowflake/ml/dataset/dataset_reader.py +18 -118
  14. snowflake/ml/feature_store/access_manager.py +7 -1
  15. snowflake/ml/feature_store/entity.py +19 -2
  16. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  25. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  26. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  27. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  28. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  29. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  30. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  31. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  32. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  33. snowflake/ml/feature_store/feature_store.py +579 -53
  34. snowflake/ml/feature_store/feature_view.py +168 -5
  35. snowflake/ml/fileset/stage_fs.py +18 -10
  36. snowflake/ml/lineage/lineage_node.py +1 -1
  37. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  38. snowflake/ml/model/_model_composer/model_composer.py +11 -14
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +24 -16
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  41. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  42. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  43. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  44. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  45. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  46. snowflake/ml/model/_packager/model_handlers/_base.py +11 -1
  47. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  48. snowflake/ml/model/_packager/model_handlers/catboost.py +42 -0
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +68 -0
  50. snowflake/ml/model/_packager/model_handlers/xgboost.py +59 -0
  51. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  52. snowflake/ml/model/model_signature.py +4 -4
  53. snowflake/ml/model/type_hints.py +4 -0
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  56. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  57. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  58. snowflake/ml/registry/registry.py +100 -13
  59. snowflake/ml/version.py +1 -1
  60. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +48 -2
  61. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +64 -42
  62. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  63. snowflake/ml/_internal/lineage/data_source.py +0 -10
  64. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
@@ -126,9 +126,11 @@ class FeatureView(lineage_node.LineageNode):
126
126
  name: str,
127
127
  entities: List[Entity],
128
128
  feature_df: DataFrame,
129
+ *,
129
130
  timestamp_col: Optional[str] = None,
130
131
  refresh_freq: Optional[str] = None,
131
132
  desc: str = "",
133
+ warehouse: Optional[str] = None,
132
134
  **_kwargs: Any,
133
135
  ) -> None:
134
136
  """
@@ -149,7 +151,33 @@ class FeatureView(lineage_node.LineageNode):
149
151
  NOTE: If refresh_freq is not provided, then FeatureView will be registered as View on Snowflake backend
150
152
  and there won't be extra storage cost.
151
153
  desc: description of the FeatureView.
154
+ warehouse: warehouse to refresh feature view. Not needed for static feature view (refresh_freq is None).
155
+ For managed feature view, this warehouse will overwrite the default warehouse of Feature Store if it is
156
+ specified, otherwise the default warehouse will be used.
152
157
  _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
158
+
159
+ Example::
160
+
161
+ >>> fs = FeatureStore(...)
162
+ >>> # draft_fv is a local object that hasn't materialized to Snowflake backend yet.
163
+ >>> feature_df = session.sql("select f_1, f_2 from source_table")
164
+ >>> draft_fv = FeatureView(
165
+ ... name="my_fv",
166
+ ... entities=[e1, e2],
167
+ ... feature_df=feature_df,
168
+ ... timestamp_col='TS', # optional
169
+ ... refresh_freq='1d', # optional
170
+ ... desc='A line about this feature view', # optional
171
+ ... warehouse='WH' # optional, the warehouse used to refresh (managed) feature view
172
+ ... )
173
+ >>> print(draft_fv.status)
174
+ FeatureViewStatus.DRAFT
175
+ <BLANKLINE>
176
+ >>> # registered_fv is a local object that maps to a Snowflake backend object.
177
+ >>> registered_fv = fs.register_feature_view(draft_fv, "v1")
178
+ >>> print(registered_fv.status)
179
+ FeatureViewStatus.ACTIVE
180
+
153
181
  """
154
182
 
155
183
  self._name: SqlIdentifier = SqlIdentifier(name)
@@ -167,7 +195,7 @@ class FeatureView(lineage_node.LineageNode):
167
195
  self._refresh_freq: Optional[str] = refresh_freq
168
196
  self._database: Optional[SqlIdentifier] = None
169
197
  self._schema: Optional[SqlIdentifier] = None
170
- self._warehouse: Optional[SqlIdentifier] = None
198
+ self._warehouse: Optional[SqlIdentifier] = SqlIdentifier(warehouse) if warehouse is not None else None
171
199
  self._refresh_mode: Optional[str] = None
172
200
  self._refresh_mode_reason: Optional[str] = None
173
201
  self._owner: Optional[str] = None
@@ -185,6 +213,33 @@ class FeatureView(lineage_node.LineageNode):
185
213
 
186
214
  Raises:
187
215
  ValueError: if selected feature names is not found in the FeatureView.
216
+
217
+ Example::
218
+
219
+ >>> fs = FeatureStore(...)
220
+ >>> e = fs.get_entity('TRIP_ID')
221
+ >>> # feature_df contains 3 features and 1 entity
222
+ >>> feature_df = session.table(source_table).select(
223
+ ... 'TRIPDURATION',
224
+ ... 'START_STATION_LATITUDE',
225
+ ... 'END_STATION_LONGITUDE',
226
+ ... 'TRIP_ID'
227
+ ... )
228
+ >>> darft_fv = FeatureView(name='F_TRIP', entities=[e], feature_df=feature_df)
229
+ >>> fv = fs.register_feature_view(darft_fv, version='1.0')
230
+ >>> # shows all 3 features
231
+ >>> fv.feature_names
232
+ ['TRIPDURATION', 'START_STATION_LATITUDE', 'END_STATION_LONGITUDE']
233
+ <BLANKLINE>
234
+ >>> # slice a subset of features
235
+ >>> fv_slice = fv.slice(['TRIPDURATION', 'START_STATION_LATITUDE'])
236
+ >>> fv_slice.names
237
+ ['TRIPDURATION', 'START_STATION_LATITUDE']
238
+ <BLANKLINE>
239
+ >>> # query the full set of features in original feature view
240
+ >>> fv_slice.feature_view_ref.feature_names
241
+ ['TRIPDURATION', 'START_STATION_LATITUDE', 'END_STATION_LONGITUDE']
242
+
188
243
  """
189
244
 
190
245
  res = []
@@ -196,14 +251,30 @@ class FeatureView(lineage_node.LineageNode):
196
251
  return FeatureViewSlice(self, res)
197
252
 
198
253
  def fully_qualified_name(self) -> str:
199
- """Returns the fully qualified name (<database_name>.<schema_name>.<feature_view_name>) for the
200
- FeatureView in Snowflake.
254
+ """
255
+ Returns the fully qualified name (<database_name>.<schema_name>.<feature_view_name>) for the
256
+ FeatureView in Snowflake.
201
257
 
202
258
  Returns:
203
259
  fully qualified name string.
204
260
 
205
261
  Raises:
206
262
  RuntimeError: if the FeatureView is not registered.
263
+
264
+ Example::
265
+
266
+ >>> fs = FeatureStore(...)
267
+ >>> e = fs.get_entity('TRIP_ID')
268
+ >>> feature_df = session.table(source_table).select(
269
+ ... 'TRIPDURATION',
270
+ ... 'START_STATION_LATITUDE',
271
+ ... 'TRIP_ID'
272
+ ... )
273
+ >>> darft_fv = FeatureView(name='F_TRIP', entities=[e], feature_df=feature_df)
274
+ >>> registered_fv = fs.register_feature_view(darft_fv, version='1.0')
275
+ >>> registered_fv.fully_qualified_name()
276
+ 'MY_DB.MY_SCHEMA."F_TRIP$1.0"'
277
+
207
278
  """
208
279
  if self.status == FeatureViewStatus.DRAFT or self.version is None:
209
280
  raise RuntimeError(f"FeatureView {self.name} has not been registered.")
@@ -221,6 +292,22 @@ class FeatureView(lineage_node.LineageNode):
221
292
 
222
293
  Raises:
223
294
  ValueError: if feature name is not found in the FeatureView.
295
+
296
+ Example::
297
+
298
+ >>> fs = FeatureStore(...)
299
+ >>> e = fs.get_entity('TRIP_ID')
300
+ >>> feature_df = session.table(source_table).select('TRIPDURATION', 'START_STATION_LATITUDE', 'TRIP_ID')
301
+ >>> draft_fv = FeatureView(name='F_TRIP', entities=[e], feature_df=feature_df)
302
+ >>> draft_fv = draft_fv.attach_feature_desc({
303
+ ... "TRIPDURATION": "Duration of a trip.",
304
+ ... "START_STATION_LATITUDE": "Latitude of the start station."
305
+ ... })
306
+ >>> registered_fv = fs.register_feature_view(draft_fv, version='1.0')
307
+ >>> registered_fv.feature_descs
308
+ OrderedDict([('TRIPDURATION', 'Duration of a trip.'),
309
+ ('START_STATION_LATITUDE', 'Latitude of the start station.')])
310
+
224
311
  """
225
312
  for f, d in descs.items():
226
313
  f = SqlIdentifier(f)
@@ -254,6 +341,31 @@ class FeatureView(lineage_node.LineageNode):
254
341
 
255
342
  @desc.setter
256
343
  def desc(self, new_value: str) -> None:
344
+ """Set the description of feature view.
345
+
346
+ Args:
347
+ new_value: new value of description.
348
+
349
+ Example::
350
+
351
+ >>> fs = FeatureStore(...)
352
+ >>> e = fs.get_entity('TRIP_ID')
353
+ >>> darft_fv = FeatureView(
354
+ ... name='F_TRIP',
355
+ ... entities=[e],
356
+ ... feature_df=feature_df,
357
+ ... desc='old desc'
358
+ ... )
359
+ >>> fv_1 = fs.register_feature_view(darft_fv, version='1.0')
360
+ >>> print(fv_1.desc)
361
+ old desc
362
+ <BLANKLINE>
363
+ >>> darft_fv.desc = 'NEW DESC'
364
+ >>> fv_2 = fs.register_feature_view(darft_fv, version='2.0')
365
+ >>> print(fv_2.desc)
366
+ NEW DESC
367
+
368
+ """
257
369
  warnings.warn(
258
370
  "You must call register_feature_view() to make it effective. "
259
371
  "Or use update_feature_view(desc=<new_value>).",
@@ -288,6 +400,31 @@ class FeatureView(lineage_node.LineageNode):
288
400
 
289
401
  @refresh_freq.setter
290
402
  def refresh_freq(self, new_value: str) -> None:
403
+ """Set refresh frequency of feature view.
404
+
405
+ Args:
406
+ new_value: The new value of refresh frequency.
407
+
408
+ Example::
409
+
410
+ >>> fs = FeatureStore(...)
411
+ >>> e = fs.get_entity('TRIP_ID')
412
+ >>> darft_fv = FeatureView(
413
+ ... name='F_TRIP',
414
+ ... entities=[e],
415
+ ... feature_df=feature_df,
416
+ ... refresh_freq='1d'
417
+ ... )
418
+ >>> fv_1 = fs.register_feature_view(darft_fv, version='1.0')
419
+ >>> print(fv_1.refresh_freq)
420
+ 1 day
421
+ <BLANKLINE>
422
+ >>> darft_fv.refresh_freq = '12h'
423
+ >>> fv_2 = fs.register_feature_view(darft_fv, version='2.0')
424
+ >>> print(fv_2.refresh_freq)
425
+ 12 hours
426
+
427
+ """
291
428
  warnings.warn(
292
429
  "You must call register_feature_view() to make it effective. "
293
430
  "Or use update_feature_view(refresh_freq=<new_value>).",
@@ -310,6 +447,32 @@ class FeatureView(lineage_node.LineageNode):
310
447
 
311
448
  @warehouse.setter
312
449
  def warehouse(self, new_value: str) -> None:
450
+ """Set warehouse of feature view.
451
+
452
+ Args:
453
+ new_value: The new value of warehouse.
454
+
455
+ Example::
456
+
457
+ >>> fs = FeatureStore(...)
458
+ >>> e = fs.get_entity('TRIP_ID')
459
+ >>> darft_fv = FeatureView(
460
+ ... name='F_TRIP',
461
+ ... entities=[e],
462
+ ... feature_df=feature_df,
463
+ ... refresh_freq='1d',
464
+ ... warehouse='WH1',
465
+ ... )
466
+ >>> fv_1 = fs.register_feature_view(darft_fv, version='1.0')
467
+ >>> print(fv_1.warehouse)
468
+ WH1
469
+ <BLANKLINE>
470
+ >>> darft_fv.warehouse = 'WH2'
471
+ >>> fv_2 = fs.register_feature_view(darft_fv, version='2.0')
472
+ >>> print(fv_2.warehouse)
473
+ WH2
474
+
475
+ """
313
476
  warnings.warn(
314
477
  "You must call register_feature_view() to make it effective. "
315
478
  "Or use update_feature_view(warehouse=<new_value>).",
@@ -456,7 +619,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
456
619
 
457
620
  entities = []
458
621
  for e_json in json_dict["_entities"]:
459
- e = Entity(e_json["name"], e_json["join_keys"], e_json["desc"])
622
+ e = Entity(e_json["name"], e_json["join_keys"], desc=e_json["desc"])
460
623
  e.owner = e_json["owner"]
461
624
  entities.append(e)
462
625
 
@@ -504,7 +667,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
504
667
  original_exception=ValueError("No active warehouse selected in the current session"),
505
668
  )
506
669
 
507
- fs = feature_store.FeatureStore(session, db_name, feature_store_name, session_warehouse)
670
+ fs = feature_store.FeatureStore(session, db_name, feature_store_name, default_warehouse=session_warehouse)
508
671
  return fs.get_feature_view(feature_view_name, version) # type: ignore[no-any-return]
509
672
 
510
673
  @staticmethod
@@ -234,21 +234,29 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
234
234
 
235
235
  Raises:
236
236
  SnowflakeMLException: An error occurred when the given path points to a file that cannot be found.
237
+ snowpark_exceptions.SnowparkClientException: File access failed with a Snowpark exception
237
238
  """
238
239
  path = path.lstrip("/")
239
240
  if self._USE_FALLBACK_FILE_ACCESS:
240
241
  return self._open_with_snowpark(path)
241
242
  cached_presigned_url = self._url_cache.get(path, None)
242
- if not cached_presigned_url:
243
- res = self._fetch_presigned_urls([path])
244
- url = res[0][1]
245
- expire_at = time.time() + _PRESIGNED_URL_LIFETIME_SEC
246
- cached_presigned_url = _PresignedUrl(url, expire_at)
247
- self._url_cache[path] = cached_presigned_url
248
- logging.debug(f"Retrieved presigned url for {path}.")
249
- elif cached_presigned_url.is_expiring():
250
- self.optimize_read()
251
- cached_presigned_url = self._url_cache[path]
243
+ try:
244
+ if not cached_presigned_url:
245
+ res = self._fetch_presigned_urls([path])
246
+ url = res[0][1]
247
+ expire_at = time.time() + _PRESIGNED_URL_LIFETIME_SEC
248
+ cached_presigned_url = _PresignedUrl(url, expire_at)
249
+ self._url_cache[path] = cached_presigned_url
250
+ logging.debug(f"Retrieved presigned url for {path}.")
251
+ elif cached_presigned_url.is_expiring():
252
+ self.optimize_read()
253
+ cached_presigned_url = self._url_cache[path]
254
+ except snowpark_exceptions.SnowparkClientException as e:
255
+ if self._USE_FALLBACK_FILE_ACCESS == False: # noqa: E712 # Fallback disabled
256
+ raise
257
+ # This may be an intermittent failure, so don't set _USE_FALLBACK_FILE_ACCESS = True
258
+ logging.warning(f"Pre-signed URL generation failed with {e.message}, trying fallback file access")
259
+ return self._open_with_snowpark(path)
252
260
  url = cached_presigned_url.url
253
261
  try:
254
262
  return self._fs._open(url, mode=mode, **kwargs)
@@ -118,7 +118,7 @@ class LineageNode:
118
118
  )
119
119
  domain = lineage_object["domain"].lower()
120
120
  if domain_filter is None or domain in domain_filter:
121
- if domain in DOMAIN_LINEAGE_REGISTRY:
121
+ if domain in DOMAIN_LINEAGE_REGISTRY and lineage_object["status"] == "ACTIVE":
122
122
  lineage_nodes.append(
123
123
  DOMAIN_LINEAGE_REGISTRY[domain]._load_from_lineage_node(
124
124
  self._session, lineage_object["name"], lineage_object.get("version")
@@ -85,9 +85,8 @@ def _run_setup() -> None:
85
85
 
86
86
  TARGET_METHOD = os.getenv("TARGET_METHOD")
87
87
 
88
- _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", None)
89
-
90
- _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env) if _concurrent_requests_max_env else None
88
+ _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", "1")
89
+ _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env)
91
90
 
92
91
  with tempfile.TemporaryDirectory() as tmp_dir:
93
92
  if zipfile.is_zipfile(model_zip_stage_path):
@@ -11,7 +11,8 @@ from packaging import requirements
11
11
  from typing_extensions import deprecated
12
12
 
13
13
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
- from snowflake.ml._internal.lineage import data_source, lineage_utils
14
+ from snowflake.ml._internal.lineage import lineage_utils
15
+ from snowflake.ml.data import data_source
15
16
  from snowflake.ml.model import model_signature, type_hints as model_types
16
17
  from snowflake.ml.model._model_composer.model_manifest import model_manifest
17
18
  from snowflake.ml.model._packager import model_packager
@@ -128,16 +129,14 @@ class ModelComposer:
128
129
  file_utils.copytree(
129
130
  str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
130
131
  )
131
-
132
- file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
133
-
134
- self.manifest.save(
135
- session=self.session,
136
- model_meta=model_metadata,
137
- model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
138
- options=options,
139
- data_sources=self._get_data_sources(model, sample_input_data),
140
- )
132
+ self.manifest.save(
133
+ model_meta=self.packager.meta,
134
+ model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
135
+ options=options,
136
+ data_sources=self._get_data_sources(model, sample_input_data),
137
+ )
138
+ else:
139
+ file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
141
140
 
142
141
  file_utils.upload_directory_to_stage(
143
142
  self.session,
@@ -186,6 +185,4 @@ class ModelComposer:
186
185
  data_sources = lineage_utils.get_data_sources(model)
187
186
  if not data_sources and sample_input_data is not None:
188
187
  data_sources = lineage_utils.get_data_sources(sample_input_data)
189
- if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
190
- return data_sources
191
- return None
188
+ return data_sources
@@ -1,11 +1,12 @@
1
1
  import collections
2
2
  import copy
3
3
  import pathlib
4
+ import warnings
4
5
  from typing import List, Optional, cast
5
6
 
6
7
  import yaml
7
8
 
8
- from snowflake.ml._internal.lineage import data_source
9
+ from snowflake.ml.data import data_source
9
10
  from snowflake.ml.model import type_hints
10
11
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
11
12
  from snowflake.ml.model._model_composer.model_method import (
@@ -16,7 +17,6 @@ from snowflake.ml.model._packager.model_meta import (
16
17
  model_meta as model_meta_api,
17
18
  model_meta_schema,
18
19
  )
19
- from snowflake.snowpark import Session
20
20
 
21
21
 
22
22
  class ModelManifest:
@@ -36,9 +36,8 @@ class ModelManifest:
36
36
 
37
37
  def save(
38
38
  self,
39
- session: Session,
40
39
  model_meta: model_meta_api.ModelMetadata,
41
- model_file_rel_path: pathlib.PurePosixPath,
40
+ model_rel_path: pathlib.PurePosixPath,
42
41
  options: Optional[type_hints.ModelSaveOption] = None,
43
42
  data_sources: Optional[List[data_source.DataSource]] = None,
44
43
  ) -> None:
@@ -47,10 +46,10 @@ class ModelManifest:
47
46
 
48
47
  runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
49
48
  runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
50
- runtime_to_use.imports.append(model_file_rel_path)
49
+ runtime_to_use.imports.append(str(model_rel_path) + "/")
51
50
  runtime_dict = runtime_to_use.save(self.workspace_path)
52
51
 
53
- self.function_generator = function_generator.FunctionGenerator(model_file_rel_path=model_file_rel_path)
52
+ self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
54
53
  self.methods: List[model_method.ModelMethod] = []
55
54
  for target_method in model_meta.signatures.keys():
56
55
  method = model_method.ModelMethod(
@@ -75,6 +74,16 @@ class ModelManifest:
75
74
  "In this case, set case_sensitive as True for those methods to distinguish them."
76
75
  )
77
76
 
77
+ dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
78
+ if options.get("include_pip_dependencies"):
79
+ warnings.warn(
80
+ "`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
81
+ "be warehouse-compabible. The model may need to be run in SPCS.",
82
+ category=UserWarning,
83
+ stacklevel=1,
84
+ )
85
+ dependencies["pip"] = runtime_dict["dependencies"]["pip"]
86
+
78
87
  manifest_dict = model_manifest_schema.ModelManifestDict(
79
88
  manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
80
89
  runtimes={
@@ -82,9 +91,7 @@ class ModelManifest:
82
91
  language="PYTHON",
83
92
  version=runtime_to_use.runtime_env.python_version,
84
93
  imports=runtime_dict["imports"],
85
- dependencies=model_manifest_schema.ModelRuntimeDependenciesDict(
86
- conda=runtime_dict["dependencies"]["conda"]
87
- ),
94
+ dependencies=dependencies,
88
95
  )
89
96
  },
90
97
  methods=[
@@ -127,12 +134,13 @@ class ModelManifest:
127
134
  result = []
128
135
  if data_sources:
129
136
  for source in data_sources:
130
- result.append(
131
- model_manifest_schema.LineageSourceDict(
132
- # Currently, we only support lineage from Dataset.
133
- type=model_manifest_schema.LineageSourceTypes.DATASET.value,
134
- entity=source.fully_qualified_name,
135
- version=source.version,
137
+ if isinstance(source, data_source.DatasetInfo):
138
+ result.append(
139
+ model_manifest_schema.LineageSourceDict(
140
+ # Currently, we only support lineage from Dataset.
141
+ type=model_manifest_schema.LineageSourceTypes.DATASET.value,
142
+ entity=source.fully_qualified_name,
143
+ version=source.version,
144
+ )
136
145
  )
137
- )
138
146
  return result
@@ -18,7 +18,8 @@ class ModelMethodFunctionTypes(enum.Enum):
18
18
 
19
19
 
20
20
  class ModelRuntimeDependenciesDict(TypedDict):
21
- conda: Required[str]
21
+ conda: NotRequired[str]
22
+ pip: NotRequired[str]
22
23
 
23
24
 
24
25
  class ModelRuntimeDict(TypedDict):
@@ -33,9 +33,9 @@ class FunctionGenerator:
33
33
 
34
34
  def __init__(
35
35
  self,
36
- model_file_rel_path: pathlib.PurePosixPath,
36
+ model_dir_rel_path: pathlib.PurePosixPath,
37
37
  ) -> None:
38
- self.model_file_rel_path = model_file_rel_path
38
+ self.model_dir_rel_path = model_dir_rel_path
39
39
 
40
40
  def generate(
41
41
  self,
@@ -67,7 +67,7 @@ class FunctionGenerator:
67
67
  )
68
68
 
69
69
  udf_code = function_template.format(
70
- model_file_name=self.model_file_rel_path.name,
70
+ model_dir_name=self.model_dir_rel_path.name,
71
71
  target_method=target_method,
72
72
  max_batch_size=options.get("max_batch_size", None),
73
73
  function_name=FunctionGenerator.FUNCTION_NAME,
@@ -1,12 +1,7 @@
1
- import fcntl
2
1
  import functools
3
2
  import inspect
4
3
  import os
5
4
  import sys
6
- import threading
7
- import zipfile
8
- from types import TracebackType
9
- from typing import Optional, Type
10
5
 
11
6
  import anyio
12
7
  import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
15
10
  from snowflake.ml.model._packager import model_packager
16
11
 
17
12
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
13
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
14
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
15
  TARGET_METHOD = "{target_method}"
35
16
  MAX_BATCH_SIZE = {max_batch_size}
36
17
 
37
-
38
18
  # Retrieve the model
39
19
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
20
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
21
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
22
 
52
23
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
24
+ pk = model_packager.ModelPackager(model_dir_path)
54
25
  pk.load(as_custom_model=True)
55
26
  assert pk.model, "model is not loaded"
56
27
  assert pk.meta, "model metadata is not loaded"
@@ -15,42 +15,18 @@ from _snowflake import vectorized
15
15
  from snowflake.ml.model._packager import model_packager
16
16
 
17
17
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
18
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
19
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
20
  TARGET_METHOD = "{target_method}"
35
21
  MAX_BATCH_SIZE = {max_batch_size}
36
22
 
37
-
38
23
  # Retrieve the model
39
24
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
25
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
26
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
27
 
52
28
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
29
+ pk = model_packager.ModelPackager(model_dir_path)
54
30
  pk.load(as_custom_model=True)
55
31
  assert pk.model, "model is not loaded"
56
32
  assert pk.meta, "model metadata is not loaded"
@@ -1,12 +1,7 @@
1
- import fcntl
2
1
  import functools
3
2
  import inspect
4
3
  import os
5
4
  import sys
6
- import threading
7
- import zipfile
8
- from types import TracebackType
9
- from typing import Optional, Type
10
5
 
11
6
  import anyio
12
7
  import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
15
10
  from snowflake.ml.model._packager import model_packager
16
11
 
17
12
 
18
- class FileLock:
19
- def __enter__(self) -> None:
20
- self._lock = threading.Lock()
21
- self._lock.acquire()
22
- self._fd = open("/tmp/lockfile.LOCK", "w+")
23
- fcntl.lockf(self._fd, fcntl.LOCK_EX)
24
-
25
- def __exit__(
26
- self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
27
- ) -> None:
28
- self._fd.close()
29
- self._lock.release()
30
-
31
-
32
13
  # User-defined parameters
33
- MODEL_FILE_NAME = "{model_file_name}"
14
+ MODEL_DIR_REL_PATH = "{model_dir_name}"
34
15
  TARGET_METHOD = "{target_method}"
35
16
  MAX_BATCH_SIZE = {max_batch_size}
36
17
 
37
-
38
18
  # Retrieve the model
39
19
  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
40
20
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
41
-
42
- model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
43
- zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
44
- extracted = "/tmp/models"
45
- extracted_model_dir_path = os.path.join(extracted, model_dir_name)
46
-
47
- with FileLock():
48
- if not os.path.isdir(extracted_model_dir_path):
49
- with zipfile.ZipFile(zip_model_path, "r") as myzip:
50
- myzip.extractall(extracted_model_dir_path)
21
+ model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
51
22
 
52
23
  # Load the model
53
- pk = model_packager.ModelPackager(extracted_model_dir_path)
24
+ pk = model_packager.ModelPackager(model_dir_path)
54
25
  pk.load(as_custom_model=True)
55
26
  assert pk.model, "model is not loaded"
56
27
  assert pk.meta, "model metadata is not loaded"
@@ -26,11 +26,14 @@ class ModelMethodOptions(TypedDict):
26
26
  def get_model_method_options_from_options(
27
27
  options: type_hints.ModelSaveOption, target_method: str
28
28
  ) -> ModelMethodOptions:
29
+ default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
30
+ if options.get("enable_explainability", False) and target_method.startswith("explain"):
31
+ default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
29
32
  method_option = options.get("method_options", {}).get(target_method, {})
30
- global_function_type = options.get("function_type", model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value)
33
+ global_function_type = options.get("function_type", default_function_type)
31
34
  function_type = method_option.get("function_type", global_function_type)
32
35
  if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
33
- raise NotImplementedError
36
+ raise NotImplementedError(f"Function type {function_type} is not supported.")
34
37
 
35
38
  return ModelMethodOptions(
36
39
  case_sensitive=method_option.get("case_sensitive", False),