snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +67 -10
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +12 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +579 -53
- snowflake/ml/feature_store/feature_view.py +168 -5
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +11 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +24 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +11 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +42 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +68 -0
- snowflake/ml/model/_packager/model_handlers/xgboost.py +59 -0
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +4 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +48 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +64 -42
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/ml/feature_store/feature_view.py  CHANGED

@@ -126,9 +126,11 @@ class FeatureView(lineage_node.LineageNode):
         name: str,
         entities: List[Entity],
         feature_df: DataFrame,
+        *,
         timestamp_col: Optional[str] = None,
         refresh_freq: Optional[str] = None,
         desc: str = "",
+        warehouse: Optional[str] = None,
         **_kwargs: Any,
     ) -> None:
         """
@@ -149,7 +151,33 @@ class FeatureView(lineage_node.LineageNode):
                 NOTE: If refresh_freq is not provided, then FeatureView will be registered as View on Snowflake backend
                     and there won't be extra storage cost.
             desc: description of the FeatureView.
+            warehouse: warehouse to refresh feature view. Not needed for static feature view (refresh_freq is None).
+                For managed feature view, this warehouse will overwrite the default warehouse of Feature Store if it is
+                specified, otherwise the default warehouse will be used.
             _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> # draft_fv is a local object that hasn't materiaized to Snowflake backend yet.
+            >>> feature_df = session.sql("select f_1, f_2 from source_table")
+            >>> draft_fv = FeatureView(
+            ...     name="my_fv",
+            ...     entities=[e1, e2],
+            ...     feature_df=feature_df,
+            ...     timestamp_col='TS',  # optional
+            ...     refresh_freq='1d',  # optional
+            ...     desc='A line about this feature view',  # optional
+            ...     warehouse='WH'  # optional, the warehouse used to refresh (managed) feature view
+            ... )
+            >>> print(draft_fv.status)
+            FeatureViewStatus.DRAFT
+            <BLANKLINE>
+            >>> # registered_fv is a local object that maps to a Snowflake backend object.
+            >>> registered_fv = fs.register_feature_view(draft_fv, "v1")
+            >>> print(registered_fv.status)
+            FeatureViewStatus.ACTIVE
+
         """
 
         self._name: SqlIdentifier = SqlIdentifier(name)
@@ -167,7 +195,7 @@ class FeatureView(lineage_node.LineageNode):
         self._refresh_freq: Optional[str] = refresh_freq
         self._database: Optional[SqlIdentifier] = None
         self._schema: Optional[SqlIdentifier] = None
-        self._warehouse: Optional[SqlIdentifier] = None
+        self._warehouse: Optional[SqlIdentifier] = SqlIdentifier(warehouse) if warehouse is not None else None
         self._refresh_mode: Optional[str] = None
         self._refresh_mode_reason: Optional[str] = None
         self._owner: Optional[str] = None
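Taken together, the three hunks above make everything after `feature_df` keyword-only and add a `warehouse` argument that is normalized to a `SqlIdentifier`. A minimal sketch of what a 1.6.0 caller looks like; the connection setup and table names here are illustrative, not taken from the diff:

```python
from snowflake.snowpark import Session
from snowflake.ml.feature_store import Entity, FeatureView

# Illustrative connection parameters; replace with your own account settings.
connection_parameters = {"account": "...", "user": "...", "password": "...",
                         "database": "MY_DB", "schema": "MY_SCHEMA", "warehouse": "MY_WH"}
session = Session.builder.configs(connection_parameters).create()

trip_id = Entity(name="TRIP_ID", join_keys=["TRIP_ID"])
feature_df = session.table("TRIPS").select("TRIP_ID", "TRIPDURATION")

# 1.6.0: everything after feature_df must be passed by keyword (note the new `*`).
fv = FeatureView(
    name="F_TRIP",
    entities=[trip_id],
    feature_df=feature_df,
    refresh_freq="1d",
    warehouse="TRANSFORM_WH",  # new argument; overrides the Feature Store default warehouse for this view
)

# Passing timestamp_col/refresh_freq positionally, as older code might have done, now raises TypeError:
# FeatureView("F_TRIP", [trip_id], feature_df, "TS", "1d")
```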
@@ -185,6 +213,33 @@ class FeatureView(lineage_node.LineageNode):
 
         Raises:
             ValueError: if selected feature names is not found in the FeatureView.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> e = fs.get_entity('TRIP_ID')
+            >>> # feature_df contains 3 features and 1 entity
+            >>> feature_df = session.table(source_table).select(
+            ...     'TRIPDURATION',
+            ...     'START_STATION_LATITUDE',
+            ...     'END_STATION_LONGITUDE',
+            ...     'TRIP_ID'
+            ... )
+            >>> darft_fv = FeatureView(name='F_TRIP', entities=[e], feature_df=feature_df)
+            >>> fv = fs.register_feature_view(darft_fv, version='1.0')
+            >>> # shows all 3 features
+            >>> fv.feature_names
+            ['TRIPDURATION', 'START_STATION_LATITUDE', 'END_STATION_LONGITUDE']
+            <BLANKLINE>
+            >>> # slice a subset of features
+            >>> fv_slice = fv.slice(['TRIPDURATION', 'START_STATION_LATITUDE'])
+            >>> fv_slice.names
+            ['TRIPDURATION', 'START_STATION_LATITUDE']
+            <BLANKLINE>
+            >>> # query the full set of features in original feature view
+            >>> fv_slice.feature_view_ref.feature_names
+            ['TRIPDURATION', 'START_STATION_LATITUDE', 'END_STATION_LONGITUDE']
+
         """
 
         res = []
@@ -196,14 +251,30 @@ class FeatureView(lineage_node.LineageNode):
         return FeatureViewSlice(self, res)
 
     def fully_qualified_name(self) -> str:
-        """
-        …
+        """
+        Returns the fully qualified name (<database_name>.<schema_name>.<feature_view_name>) for the
+        FeatureView in Snowflake.
 
         Returns:
             fully qualified name string.
 
         Raises:
             RuntimeError: if the FeatureView is not registered.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> e = fs.get_entity('TRIP_ID')
+            >>> feature_df = session.table(source_table).select(
+            ...     'TRIPDURATION',
+            ...     'START_STATION_LATITUDE',
+            ...     'TRIP_ID'
+            ... )
+            >>> darft_fv = FeatureView(name='F_TRIP', entities=[e], feature_df=feature_df)
+            >>> registered_fv = fs.register_feature_view(darft_fv, version='1.0')
+            >>> registered_fv.fully_qualified_name()
+            'MY_DB.MY_SCHEMA."F_TRIP$1.0"'
+
         """
         if self.status == FeatureViewStatus.DRAFT or self.version is None:
             raise RuntimeError(f"FeatureView {self.name} has not been registered.")
@@ -221,6 +292,22 @@ class FeatureView(lineage_node.LineageNode):
 
         Raises:
             ValueError: if feature name is not found in the FeatureView.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> e = fs.get_entity('TRIP_ID')
+            >>> feature_df = session.table(source_table).select('TRIPDURATION', 'START_STATION_LATITUDE', 'TRIP_ID')
+            >>> draft_fv = FeatureView(name='F_TRIP', entities=[e], feature_df=feature_df)
+            >>> draft_fv = draft_fv.attach_feature_desc({
+            ...     "TRIPDURATION": "Duration of a trip.",
+            ...     "START_STATION_LATITUDE": "Latitude of the start station."
+            ... })
+            >>> registered_fv = fs.register_feature_view(draft_fv, version='1.0')
+            >>> registered_fv.feature_descs
+            OrderedDict([('TRIPDURATION', 'Duration of a trip.'),
+            ('START_STATION_LATITUDE', 'Latitude of the start station.')])
+
         """
         for f, d in descs.items():
             f = SqlIdentifier(f)
@@ -254,6 +341,31 @@ class FeatureView(lineage_node.LineageNode):
 
     @desc.setter
     def desc(self, new_value: str) -> None:
+        """Set the description of feature view.
+
+        Args:
+            new_value: new value of description.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> e = fs.get_entity('TRIP_ID')
+            >>> darft_fv = FeatureView(
+            ...     name='F_TRIP',
+            ...     entities=[e],
+            ...     feature_df=feature_df,
+            ...     desc='old desc'
+            ... )
+            >>> fv_1 = fs.register_feature_view(darft_fv, version='1.0')
+            >>> print(fv_1.desc)
+            old desc
+            <BLANKLINE>
+            >>> darft_fv.desc = 'NEW DESC'
+            >>> fv_2 = fs.register_feature_view(darft_fv, version='2.0')
+            >>> print(fv_2.desc)
+            NEW DESC
+
+        """
         warnings.warn(
             "You must call register_feature_view() to make it effective. "
             "Or use update_feature_view(desc=<new_value>).",
@@ -288,6 +400,31 @@ class FeatureView(lineage_node.LineageNode):
 
     @refresh_freq.setter
     def refresh_freq(self, new_value: str) -> None:
+        """Set refresh frequency of feature view.
+
+        Args:
+            new_value: The new value of refresh frequency.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> e = fs.get_entity('TRIP_ID')
+            >>> darft_fv = FeatureView(
+            ...     name='F_TRIP',
+            ...     entities=[e],
+            ...     feature_df=feature_df,
+            ...     refresh_freq='1d'
+            ... )
+            >>> fv_1 = fs.register_feature_view(darft_fv, version='1.0')
+            >>> print(fv_1.refresh_freq)
+            1 day
+            <BLANKLINE>
+            >>> darft_fv.refresh_freq = '12h'
+            >>> fv_2 = fs.register_feature_view(darft_fv, version='2.0')
+            >>> print(fv_2.refresh_freq)
+            12 hours
+
+        """
         warnings.warn(
             "You must call register_feature_view() to make it effective. "
             "Or use update_feature_view(refresh_freq=<new_value>).",
@@ -310,6 +447,32 @@ class FeatureView(lineage_node.LineageNode):
 
     @warehouse.setter
     def warehouse(self, new_value: str) -> None:
+        """Set warehouse of feature view.
+
+        Args:
+            new_value: The new value of warehouse.
+
+        Example::
+
+            >>> fs = FeatureStore(...)
+            >>> e = fs.get_entity('TRIP_ID')
+            >>> darft_fv = FeatureView(
+            ...     name='F_TRIP',
+            ...     entities=[e],
+            ...     feature_df=feature_df,
+            ...     refresh_freq='1d',
+            ...     warehouse='WH1',
+            ... )
+            >>> fv_1 = fs.register_feature_view(darft_fv, version='1.0')
+            >>> print(fv_1.warehouse)
+            WH1
+            <BLANKLINE>
+            >>> darft_fv.warehouse = 'WH2'
+            >>> fv_2 = fs.register_feature_view(darft_fv, version='2.0')
+            >>> print(fv_2.warehouse)
+            WH2
+
+        """
         warnings.warn(
             "You must call register_feature_view() to make it effective. "
             "Or use update_feature_view(warehouse=<new_value>).",
@@ -456,7 +619,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
 
         entities = []
         for e_json in json_dict["_entities"]:
-            e = Entity(e_json["name"], e_json["join_keys"], e_json["desc"])
+            e = Entity(e_json["name"], e_json["join_keys"], desc=e_json["desc"])
             e.owner = e_json["owner"]
             entities.append(e)
 
@@ -504,7 +667,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
                 original_exception=ValueError("No active warehouse selected in the current session"),
             )
 
-        fs = feature_store.FeatureStore(session, db_name, feature_store_name, session_warehouse)
+        fs = feature_store.FeatureStore(session, db_name, feature_store_name, default_warehouse=session_warehouse)
        return fs.get_feature_view(feature_view_name, version)  # type: ignore[no-any-return]
 
     @staticmethod
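Both call-site fixes above move positional arguments to keywords: `Entity(..., desc=...)` during deserialization and `FeatureStore(..., default_warehouse=...)` when resolving a feature view from lineage. A short sketch of the corresponding public usage, assuming an existing Snowpark `session` and illustrative object names:

```python
from snowflake.ml.feature_store import CreationMode, Entity, FeatureStore

# Assumed: `session` is an existing snowflake.snowpark.Session.
fs = FeatureStore(
    session,
    "MY_DB",                    # database that holds the feature store schema
    "MY_FEATURE_STORE",         # feature store (schema) name
    default_warehouse="MY_WH",  # passed by keyword, matching the updated call site above
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST,
)

# Entity descriptions are likewise passed by keyword.
trip_id = Entity(name="TRIP_ID", join_keys=["TRIP_ID"], desc="Unique identifier of a trip.")
fs.register_entity(trip_id)
```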
snowflake/ml/fileset/stage_fs.py  CHANGED

@@ -234,21 +234,29 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
 
         Raises:
             SnowflakeMLException: An error occurred when the given path points to a file that cannot be found.
+            snowpark_exceptions.SnowparkClientException: File access failed with a Snowpark exception
         """
         path = path.lstrip("/")
         if self._USE_FALLBACK_FILE_ACCESS:
             return self._open_with_snowpark(path)
         cached_presigned_url = self._url_cache.get(path, None)
-        … (10 removed lines not captured in this rendering)
+        try:
+            if not cached_presigned_url:
+                res = self._fetch_presigned_urls([path])
+                url = res[0][1]
+                expire_at = time.time() + _PRESIGNED_URL_LIFETIME_SEC
+                cached_presigned_url = _PresignedUrl(url, expire_at)
+                self._url_cache[path] = cached_presigned_url
+                logging.debug(f"Retrieved presigned url for {path}.")
+            elif cached_presigned_url.is_expiring():
+                self.optimize_read()
+                cached_presigned_url = self._url_cache[path]
+        except snowpark_exceptions.SnowparkClientException as e:
+            if self._USE_FALLBACK_FILE_ACCESS == False:  # noqa: E712 # Fallback disabled
+                raise
+            # This may be an intermittent failure, so don't set _USE_FALLBACK_FILE_ACCESS = True
+            logging.warning(f"Pre-signed URL generation failed with {e.message}, trying fallback file access")
+            return self._open_with_snowpark(path)
         url = cached_presigned_url.url
         try:
             return self._fs._open(url, mode=mode, **kwargs)
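The new body retrieves (or refreshes) a cached presigned URL inside a try/except, so an intermittent `SnowparkClientException` downgrades to the slower Snowpark-based file access instead of failing the read. A stripped-down sketch of the same caching-plus-fallback pattern; the class and callables below are illustrative, not the module's private API:

```python
import logging
import time
from typing import Callable, Dict, Tuple

PRESIGNED_URL_LIFETIME_SEC = 14 * 60  # illustrative lifetime, not the library's constant


class PresignedUrlReader:
    """Cache presigned URLs per path and fall back to a slower reader on failure."""

    def __init__(
        self,
        fetch_url: Callable[[str], str],        # e.g. wraps a server-side presigned-URL call
        fallback_read: Callable[[str], bytes],  # e.g. a Snowpark-based file read
    ) -> None:
        self._fetch_url = fetch_url
        self._fallback_read = fallback_read
        self._cache: Dict[str, Tuple[str, float]] = {}  # path -> (url, expires_at)

    def read(self, path: str, read_with_url: Callable[[str], bytes]) -> bytes:
        try:
            url, expires_at = self._cache.get(path, ("", 0.0))
            if not url or expires_at - time.time() < 60:  # missing or about to expire
                url = self._fetch_url(path)
                self._cache[path] = (url, time.time() + PRESIGNED_URL_LIFETIME_SEC)
        except Exception as exc:  # the real code catches SnowparkClientException specifically
            logging.warning("Presigned URL generation failed (%s); using fallback file access", exc)
            return self._fallback_read(path)
        return read_with_url(url)
```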
snowflake/ml/lineage/lineage_node.py  CHANGED

@@ -118,7 +118,7 @@ class LineageNode:
             )
             domain = lineage_object["domain"].lower()
             if domain_filter is None or domain in domain_filter:
-                if domain in DOMAIN_LINEAGE_REGISTRY:
+                if domain in DOMAIN_LINEAGE_REGISTRY and lineage_object["status"] == "ACTIVE":
                     lineage_nodes.append(
                         DOMAIN_LINEAGE_REGISTRY[domain]._load_from_lineage_node(
                             self._session, lineage_object["name"], lineage_object.get("version")
snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py  CHANGED

@@ -85,9 +85,8 @@ def _run_setup() -> None:
 
     TARGET_METHOD = os.getenv("TARGET_METHOD")
 
-    _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX",
-    …
-    _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env) if _concurrent_requests_max_env else None
+    _concurrent_requests_max_env = os.getenv("_CONCURRENT_REQUESTS_MAX", "1")
+    _CONCURRENT_REQUESTS_MAX = int(_concurrent_requests_max_env)
 
     with tempfile.TemporaryDirectory() as tmp_dir:
         if zipfile.is_zipfile(model_zip_stage_path):
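The practical effect: when `_CONCURRENT_REQUESTS_MAX` is unset, the old code could end up with `None` (treated as no limit), while the new code always parses an integer and defaults to 1. A short illustration of the difference, assuming the variable is absent from the environment:

```python
import os

# Old behavior: a falsy/unset value fell through to None, i.e. no concurrency cap.
old_raw = os.getenv("_CONCURRENT_REQUESTS_MAX")              # None when unset
old_limit = int(old_raw) if old_raw else None                # -> None

# New behavior: the default is "1", so the cap is always an int (1 unless overridden).
new_limit = int(os.getenv("_CONCURRENT_REQUESTS_MAX", "1"))  # -> 1
```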
snowflake/ml/model/_model_composer/model_composer.py  CHANGED

@@ -11,7 +11,8 @@ from packaging import requirements
 from typing_extensions import deprecated
 
 from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
-from snowflake.ml._internal.lineage import
+from snowflake.ml._internal.lineage import lineage_utils
+from snowflake.ml.data import data_source
 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._model_composer.model_manifest import model_manifest
 from snowflake.ml.model._packager import model_packager

@@ -128,16 +129,14 @@ class ModelComposer:
             file_utils.copytree(
                 str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
             )
-            … (8 removed lines not captured in this rendering)
-                data_sources=self._get_data_sources(model, sample_input_data),
-            )
+            self.manifest.save(
+                model_meta=self.packager.meta,
+                model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
+                options=options,
+                data_sources=self._get_data_sources(model, sample_input_data),
+            )
+        else:
+            file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
 
         file_utils.upload_directory_to_stage(
             self.session,

@@ -186,6 +185,4 @@ class ModelComposer:
         data_sources = lineage_utils.get_data_sources(model)
         if not data_sources and sample_input_data is not None:
             data_sources = lineage_utils.get_data_sources(sample_input_data)
-        …
-            return data_sources
-        return None
+        return data_sources
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py  CHANGED

@@ -1,11 +1,12 @@
 import collections
 import copy
 import pathlib
+import warnings
 from typing import List, Optional, cast
 
 import yaml
 
-from snowflake.ml.
+from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._model_composer.model_method import (

@@ -16,7 +17,6 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
 )
-from snowflake.snowpark import Session
 
 
 class ModelManifest:

@@ -36,9 +36,8 @@ class ModelManifest:
 
     def save(
         self,
-        session: Session,
         model_meta: model_meta_api.ModelMetadata,
-        …
+        model_rel_path: pathlib.PurePosixPath,
         options: Optional[type_hints.ModelSaveOption] = None,
         data_sources: Optional[List[data_source.DataSource]] = None,
     ) -> None:

@@ -47,10 +46,10 @@ class ModelManifest:
 
         runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
         runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
-        runtime_to_use.imports.append(
+        runtime_to_use.imports.append(str(model_rel_path) + "/")
         runtime_dict = runtime_to_use.save(self.workspace_path)
 
-        self.function_generator = function_generator.FunctionGenerator(
+        self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
         self.methods: List[model_method.ModelMethod] = []
         for target_method in model_meta.signatures.keys():
             method = model_method.ModelMethod(

@@ -75,6 +74,16 @@ class ModelManifest:
                 "In this case, set case_sensitive as True for those methods to distinguish them."
             )
 
+        dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
+        if options.get("include_pip_dependencies"):
+            warnings.warn(
+                "`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
+                "be warehouse-compabible. The model may need to be run in SPCS.",
+                category=UserWarning,
+                stacklevel=1,
+            )
+            dependencies["pip"] = runtime_dict["dependencies"]["pip"]
+
         manifest_dict = model_manifest_schema.ModelManifestDict(
             manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
             runtimes={

@@ -82,9 +91,7 @@ class ModelManifest:
                     language="PYTHON",
                     version=runtime_to_use.runtime_env.python_version,
                     imports=runtime_dict["imports"],
-                    dependencies=
-                        conda=runtime_dict["dependencies"]["conda"]
-                    ),
+                    dependencies=dependencies,
                 )
             },
             methods=[

@@ -127,12 +134,13 @@ class ModelManifest:
         result = []
         if data_sources:
             for source in data_sources:
-                … (6 removed lines not captured in this rendering)
+                if isinstance(source, data_source.DatasetInfo):
+                    result.append(
+                        model_manifest_schema.LineageSourceDict(
+                            # Currently, we only support lineage from Dataset.
+                            type=model_manifest_schema.LineageSourceTypes.DATASET.value,
+                            entity=source.fully_qualified_name,
+                            version=source.version,
+                        )
                     )
-                )
         return result
snowflake/ml/model/_model_composer/model_method/function_generator.py  CHANGED

@@ -33,9 +33,9 @@ class FunctionGenerator:
 
     def __init__(
         self,
-        …
+        model_dir_rel_path: pathlib.PurePosixPath,
     ) -> None:
-        self.
+        self.model_dir_rel_path = model_dir_rel_path
 
     def generate(
         self,

@@ -67,7 +67,7 @@ class FunctionGenerator:
         )
 
         udf_code = function_template.format(
-            …
+            model_dir_name=self.model_dir_rel_path.name,
             target_method=target_method,
             max_batch_size=options.get("max_batch_size", None),
             function_name=FunctionGenerator.FUNCTION_NAME,
snowflake/ml/model/_model_composer/model_method/infer_function.py_template  CHANGED

@@ -1,12 +1,7 @@
-import fcntl
 import functools
 import inspect
 import os
 import sys
-import threading
-import zipfile
-from types import TracebackType
-from typing import Optional, Type
 
 import anyio
 import pandas as pd

@@ -15,42 +10,18 @@ from _snowflake import vectorized
 from snowflake.ml.model._packager import model_packager
 
 
-class FileLock:
-    def __enter__(self) -> None:
-        self._lock = threading.Lock()
-        self._lock.acquire()
-        self._fd = open("/tmp/lockfile.LOCK", "w+")
-        fcntl.lockf(self._fd, fcntl.LOCK_EX)
-
-    def __exit__(
-        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
-    ) -> None:
-        self._fd.close()
-        self._lock.release()
-
-
 # User-defined parameters
-…
+MODEL_DIR_REL_PATH = "{model_dir_name}"
 TARGET_METHOD = "{target_method}"
 MAX_BATCH_SIZE = {max_batch_size}
 
-
 # Retrieve the model
 IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
 import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
-
-model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
-zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
-extracted = "/tmp/models"
-extracted_model_dir_path = os.path.join(extracted, model_dir_name)
-
-with FileLock():
-    if not os.path.isdir(extracted_model_dir_path):
-        with zipfile.ZipFile(zip_model_path, "r") as myzip:
-            myzip.extractall(extracted_model_dir_path)
+model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
 
 # Load the model
-pk = model_packager.ModelPackager(
+pk = model_packager.ModelPackager(model_dir_path)
 pk.load(as_custom_model=True)
 assert pk.model, "model is not loaded"
 assert pk.meta, "model metadata is not loaded"
snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template  CHANGED

@@ -15,42 +15,18 @@ from _snowflake import vectorized
 from snowflake.ml.model._packager import model_packager
 
 
-class FileLock:
-    def __enter__(self) -> None:
-        self._lock = threading.Lock()
-        self._lock.acquire()
-        self._fd = open("/tmp/lockfile.LOCK", "w+")
-        fcntl.lockf(self._fd, fcntl.LOCK_EX)
-
-    def __exit__(
-        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
-    ) -> None:
-        self._fd.close()
-        self._lock.release()
-
-
 # User-defined parameters
-…
+MODEL_DIR_REL_PATH = "{model_dir_name}"
 TARGET_METHOD = "{target_method}"
 MAX_BATCH_SIZE = {max_batch_size}
 
-
 # Retrieve the model
 IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
 import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
-
-model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
-zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
-extracted = "/tmp/models"
-extracted_model_dir_path = os.path.join(extracted, model_dir_name)
-
-with FileLock():
-    if not os.path.isdir(extracted_model_dir_path):
-        with zipfile.ZipFile(zip_model_path, "r") as myzip:
-            myzip.extractall(extracted_model_dir_path)
+model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
 
 # Load the model
-pk = model_packager.ModelPackager(
+pk = model_packager.ModelPackager(model_dir_path)
 pk.load(as_custom_model=True)
 assert pk.model, "model is not loaded"
 assert pk.meta, "model metadata is not loaded"
snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template  CHANGED

@@ -1,12 +1,7 @@
-import fcntl
 import functools
 import inspect
 import os
 import sys
-import threading
-import zipfile
-from types import TracebackType
-from typing import Optional, Type
 
 import anyio
 import pandas as pd

@@ -15,42 +10,18 @@ from _snowflake import vectorized
 from snowflake.ml.model._packager import model_packager
 
 
-class FileLock:
-    def __enter__(self) -> None:
-        self._lock = threading.Lock()
-        self._lock.acquire()
-        self._fd = open("/tmp/lockfile.LOCK", "w+")
-        fcntl.lockf(self._fd, fcntl.LOCK_EX)
-
-    def __exit__(
-        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
-    ) -> None:
-        self._fd.close()
-        self._lock.release()
-
-
 # User-defined parameters
-…
+MODEL_DIR_REL_PATH = "{model_dir_name}"
 TARGET_METHOD = "{target_method}"
 MAX_BATCH_SIZE = {max_batch_size}
 
-
 # Retrieve the model
 IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
 import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
-
-model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
-zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
-extracted = "/tmp/models"
-extracted_model_dir_path = os.path.join(extracted, model_dir_name)
-
-with FileLock():
-    if not os.path.isdir(extracted_model_dir_path):
-        with zipfile.ZipFile(zip_model_path, "r") as myzip:
-            myzip.extractall(extracted_model_dir_path)
+model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
 
 # Load the model
-pk = model_packager.ModelPackager(
+pk = model_packager.ModelPackager(model_dir_path)
 pk.load(as_custom_model=True)
 assert pk.model, "model is not loaded"
 assert pk.meta, "model metadata is not loaded"
|
@@ -26,11 +26,14 @@ class ModelMethodOptions(TypedDict):
|
|
26
26
|
def get_model_method_options_from_options(
|
27
27
|
options: type_hints.ModelSaveOption, target_method: str
|
28
28
|
) -> ModelMethodOptions:
|
29
|
+
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
30
|
+
if options.get("enable_explainability", False) and target_method.startswith("explain"):
|
31
|
+
default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
29
32
|
method_option = options.get("method_options", {}).get(target_method, {})
|
30
|
-
global_function_type = options.get("function_type",
|
33
|
+
global_function_type = options.get("function_type", default_function_type)
|
31
34
|
function_type = method_option.get("function_type", global_function_type)
|
32
35
|
if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
|
33
|
-
raise NotImplementedError
|
36
|
+
raise NotImplementedError(f"Function type {function_type} is not supported.")
|
34
37
|
|
35
38
|
return ModelMethodOptions(
|
36
39
|
case_sensitive=method_option.get("case_sensitive", False),
|