snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +6 -4
- snowflake/ml/_internal/utils/service_logger.py +118 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
- snowflake/ml/data/data_connector.py +4 -34
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/dataset/dataset_reader.py +2 -8
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +55 -0
- snowflake/ml/experiment/callback/xgboost.py +63 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +159 -52
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
- snowflake/ml/jobs/_utils/spec_utils.py +1 -1
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +22 -6
- snowflake/ml/jobs/manager.py +5 -3
- snowflake/ml/model/_client/model/model_version_impl.py +56 -48
- snowflake/ml/model/_client/ops/service_ops.py +194 -14
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/models/huggingface_pipeline.py +71 -49
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +30 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
|
@@ -96,11 +96,13 @@ class ServiceLogMetadata:
|
|
|
96
96
|
msg: str,
|
|
97
97
|
is_model_build_service_done: bool,
|
|
98
98
|
is_model_logger_service_done: bool,
|
|
99
|
+
operation_id: str,
|
|
99
100
|
propagate: bool = False,
|
|
100
101
|
) -> None:
|
|
101
102
|
to_service_logger = service_logger.get_logger(
|
|
102
103
|
f"{to_service.display_service_name}-{to_service.instance_id}",
|
|
103
104
|
to_service.log_color,
|
|
105
|
+
operation_id=operation_id,
|
|
104
106
|
)
|
|
105
107
|
to_service_logger.propagate = propagate
|
|
106
108
|
self.service_logger = to_service_logger
|
|
@@ -178,9 +180,7 @@ class ServiceOperator:
|
|
|
178
180
|
service_name: sql_identifier.SqlIdentifier,
|
|
179
181
|
image_build_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
180
182
|
service_compute_pool_name: sql_identifier.SqlIdentifier,
|
|
181
|
-
|
|
182
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
183
|
-
image_repo_name: sql_identifier.SqlIdentifier,
|
|
183
|
+
image_repo: str,
|
|
184
184
|
ingress_enabled: bool,
|
|
185
185
|
max_instances: int,
|
|
186
186
|
cpu_requests: Optional[str],
|
|
@@ -191,11 +191,15 @@ class ServiceOperator:
|
|
|
191
191
|
force_rebuild: bool,
|
|
192
192
|
build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
|
|
193
193
|
block: bool,
|
|
194
|
+
progress_status: type_hints.ProgressStatus,
|
|
194
195
|
statement_params: Optional[dict[str, Any]] = None,
|
|
195
196
|
# hf model
|
|
196
197
|
hf_model_args: Optional[HFModelArgs] = None,
|
|
197
198
|
) -> Union[str, async_job.AsyncJob]:
|
|
198
199
|
|
|
200
|
+
# Generate operation ID for this deployment
|
|
201
|
+
operation_id = service_logger.get_operation_id()
|
|
202
|
+
|
|
199
203
|
# Fall back to the registry's database and schema if not provided
|
|
200
204
|
database_name = database_name or self._database_name
|
|
201
205
|
schema_name = schema_name or self._schema_name
|
|
@@ -204,8 +208,17 @@ class ServiceOperator:
|
|
|
204
208
|
service_database_name = service_database_name or database_name or self._database_name
|
|
205
209
|
service_schema_name = service_schema_name or schema_name or self._schema_name
|
|
206
210
|
|
|
211
|
+
# Parse image repo
|
|
212
|
+
image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
|
|
213
|
+
image_repo
|
|
214
|
+
)
|
|
207
215
|
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
|
208
216
|
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
217
|
+
|
|
218
|
+
# Step 1: Preparing deployment artifacts
|
|
219
|
+
progress_status.update("preparing deployment artifacts...")
|
|
220
|
+
progress_status.increment()
|
|
221
|
+
|
|
209
222
|
if self._workspace:
|
|
210
223
|
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
|
211
224
|
else:
|
|
@@ -254,6 +267,11 @@ class ServiceOperator:
|
|
|
254
267
|
**(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
|
|
255
268
|
)
|
|
256
269
|
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
|
270
|
+
|
|
271
|
+
# Step 2: Uploading deployment artifacts
|
|
272
|
+
progress_status.update("uploading deployment artifacts...")
|
|
273
|
+
progress_status.increment()
|
|
274
|
+
|
|
257
275
|
if self._workspace:
|
|
258
276
|
assert stage_path is not None
|
|
259
277
|
file_utils.upload_directory_to_stage(
|
|
@@ -276,6 +294,10 @@ class ServiceOperator:
|
|
|
276
294
|
statement_params=statement_params,
|
|
277
295
|
)
|
|
278
296
|
|
|
297
|
+
# Step 3: Initiating model deployment
|
|
298
|
+
progress_status.update("initiating model deployment...")
|
|
299
|
+
progress_status.increment()
|
|
300
|
+
|
|
279
301
|
# deploy the model service
|
|
280
302
|
query_id, async_job = self._service_client.deploy_model(
|
|
281
303
|
stage_path=stage_path if self._workspace else None,
|
|
@@ -327,17 +349,68 @@ class ServiceOperator:
|
|
|
327
349
|
model_inference_service=model_inference_service,
|
|
328
350
|
model_inference_service_exists=model_inference_service_exists,
|
|
329
351
|
force_rebuild=force_rebuild,
|
|
352
|
+
operation_id=operation_id,
|
|
330
353
|
statement_params=statement_params,
|
|
331
354
|
)
|
|
332
355
|
|
|
333
356
|
if block:
|
|
334
|
-
|
|
357
|
+
try:
|
|
358
|
+
# Step 4: Starting model build: waits for build to start
|
|
359
|
+
progress_status.update("starting model image build...")
|
|
360
|
+
progress_status.increment()
|
|
361
|
+
|
|
362
|
+
# Poll for model build to start if not using existing service
|
|
363
|
+
if not model_inference_service_exists:
|
|
364
|
+
self._wait_for_service_status(
|
|
365
|
+
model_build_service_name,
|
|
366
|
+
service_sql.ServiceStatus.RUNNING,
|
|
367
|
+
service_database_name,
|
|
368
|
+
service_schema_name,
|
|
369
|
+
async_job,
|
|
370
|
+
statement_params,
|
|
371
|
+
)
|
|
335
372
|
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
373
|
+
# Step 5: Building model image
|
|
374
|
+
progress_status.update("building model image...")
|
|
375
|
+
progress_status.increment()
|
|
376
|
+
|
|
377
|
+
# Poll for model build completion
|
|
378
|
+
if not model_inference_service_exists:
|
|
379
|
+
self._wait_for_service_status(
|
|
380
|
+
model_build_service_name,
|
|
381
|
+
service_sql.ServiceStatus.DONE,
|
|
382
|
+
service_database_name,
|
|
383
|
+
service_schema_name,
|
|
384
|
+
async_job,
|
|
385
|
+
statement_params,
|
|
386
|
+
)
|
|
387
|
+
|
|
388
|
+
# Step 6: Deploying model service (push complete, starting inference service)
|
|
389
|
+
progress_status.update("deploying model service...")
|
|
390
|
+
progress_status.increment()
|
|
391
|
+
|
|
392
|
+
log_thread.join()
|
|
393
|
+
|
|
394
|
+
res = cast(str, cast(list[row.Row], async_job.result())[0][0])
|
|
395
|
+
module_logger.info(f"Inference service {service_name} deployment complete: {res}")
|
|
396
|
+
return res
|
|
397
|
+
|
|
398
|
+
except RuntimeError as e:
|
|
399
|
+
# Handle service creation/deployment failures
|
|
400
|
+
error_msg = f"Model service deployment failed: {str(e)}"
|
|
401
|
+
module_logger.error(error_msg)
|
|
402
|
+
|
|
403
|
+
# Update progress status to show failure
|
|
404
|
+
progress_status.update(error_msg, state="error")
|
|
405
|
+
|
|
406
|
+
# Stop the log thread if it's running
|
|
407
|
+
if "log_thread" in locals() and log_thread.is_alive():
|
|
408
|
+
log_thread.join(timeout=5) # Give it a few seconds to finish gracefully
|
|
409
|
+
|
|
410
|
+
# Re-raise the exception to propagate the error
|
|
411
|
+
raise RuntimeError(error_msg) from e
|
|
412
|
+
|
|
413
|
+
return async_job
|
|
341
414
|
|
|
342
415
|
def _start_service_log_streaming(
|
|
343
416
|
self,
|
|
@@ -347,6 +420,7 @@ class ServiceOperator:
|
|
|
347
420
|
model_inference_service: ServiceLogInfo,
|
|
348
421
|
model_inference_service_exists: bool,
|
|
349
422
|
force_rebuild: bool,
|
|
423
|
+
operation_id: str,
|
|
350
424
|
statement_params: Optional[dict[str, Any]] = None,
|
|
351
425
|
) -> threading.Thread:
|
|
352
426
|
"""Start the service log streaming in a separate thread."""
|
|
@@ -360,6 +434,7 @@ class ServiceOperator:
|
|
|
360
434
|
model_inference_service,
|
|
361
435
|
model_inference_service_exists,
|
|
362
436
|
force_rebuild,
|
|
437
|
+
operation_id,
|
|
363
438
|
statement_params,
|
|
364
439
|
),
|
|
365
440
|
)
|
|
@@ -372,6 +447,7 @@ class ServiceOperator:
|
|
|
372
447
|
service_log_meta: ServiceLogMetadata,
|
|
373
448
|
model_build_service: ServiceLogInfo,
|
|
374
449
|
model_inference_service: ServiceLogInfo,
|
|
450
|
+
operation_id: str,
|
|
375
451
|
statement_params: Optional[dict[str, Any]] = None,
|
|
376
452
|
) -> None:
|
|
377
453
|
"""Helper function to fetch logs and update the service log metadata if needed.
|
|
@@ -386,6 +462,7 @@ class ServiceOperator:
|
|
|
386
462
|
service_log_meta: The ServiceLogMetadata holds the state of the service log metadata.
|
|
387
463
|
model_build_service: The ServiceLogInfo for the model build service.
|
|
388
464
|
model_inference_service: The ServiceLogInfo for the model inference service.
|
|
465
|
+
operation_id: The operation ID for the service, e.g. "model_deploy_a1b2c3d4_1703875200"
|
|
389
466
|
statement_params: The statement parameters to use for the service client.
|
|
390
467
|
"""
|
|
391
468
|
|
|
@@ -415,6 +492,7 @@ class ServiceOperator:
|
|
|
415
492
|
"Model build is not rebuilding the inference image, but using a previously built image.",
|
|
416
493
|
is_model_build_service_done=True,
|
|
417
494
|
is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
|
|
495
|
+
operation_id=operation_id,
|
|
418
496
|
)
|
|
419
497
|
|
|
420
498
|
try:
|
|
@@ -488,6 +566,7 @@ class ServiceOperator:
|
|
|
488
566
|
f"Model Logger service {service.display_service_name} complete.",
|
|
489
567
|
is_model_build_service_done=False,
|
|
490
568
|
is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
|
|
569
|
+
operation_id=operation_id,
|
|
491
570
|
)
|
|
492
571
|
# check if model build service is done
|
|
493
572
|
# and transition the service log metadata to the model inference service
|
|
@@ -497,6 +576,7 @@ class ServiceOperator:
|
|
|
497
576
|
f"Image build service {service.display_service_name} complete.",
|
|
498
577
|
is_model_build_service_done=True,
|
|
499
578
|
is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
|
|
579
|
+
operation_id=operation_id,
|
|
500
580
|
)
|
|
501
581
|
else:
|
|
502
582
|
module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
|
|
@@ -509,6 +589,7 @@ class ServiceOperator:
|
|
|
509
589
|
model_inference_service: ServiceLogInfo,
|
|
510
590
|
model_inference_service_exists: bool,
|
|
511
591
|
force_rebuild: bool,
|
|
592
|
+
operation_id: str,
|
|
512
593
|
statement_params: Optional[dict[str, Any]] = None,
|
|
513
594
|
) -> None:
|
|
514
595
|
"""Stream service logs while the async job is running."""
|
|
@@ -516,14 +597,14 @@ class ServiceOperator:
|
|
|
516
597
|
model_build_service_logger = service_logger.get_logger( # BuildJobName
|
|
517
598
|
model_build_service.display_service_name,
|
|
518
599
|
model_build_service.log_color,
|
|
600
|
+
operation_id=operation_id,
|
|
519
601
|
)
|
|
520
|
-
model_build_service_logger.propagate = False
|
|
521
602
|
if model_logger_service:
|
|
522
603
|
model_logger_service_logger = service_logger.get_logger( # ModelLoggerName
|
|
523
604
|
model_logger_service.display_service_name,
|
|
524
605
|
model_logger_service.log_color,
|
|
606
|
+
operation_id=operation_id,
|
|
525
607
|
)
|
|
526
|
-
model_logger_service_logger.propagate = False
|
|
527
608
|
|
|
528
609
|
service_log_meta = ServiceLogMetadata(
|
|
529
610
|
service_logger=model_logger_service_logger,
|
|
@@ -557,6 +638,7 @@ class ServiceOperator:
|
|
|
557
638
|
force_rebuild=force_rebuild,
|
|
558
639
|
model_build_service=model_build_service,
|
|
559
640
|
model_inference_service=model_inference_service,
|
|
641
|
+
operation_id=operation_id,
|
|
560
642
|
statement_params=statement_params,
|
|
561
643
|
)
|
|
562
644
|
except Exception as ex:
|
|
@@ -564,6 +646,7 @@ class ServiceOperator:
|
|
|
564
646
|
is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
|
|
565
647
|
contains_msg = any(msg in str(ex) for msg in ["Pending scheduling", "Waiting to start"])
|
|
566
648
|
matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
|
|
649
|
+
|
|
567
650
|
if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
|
|
568
651
|
module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
|
|
569
652
|
time.sleep(5)
|
|
@@ -603,6 +686,101 @@ class ServiceOperator:
|
|
|
603
686
|
except Exception as ex:
|
|
604
687
|
module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
|
|
605
688
|
|
|
689
|
+
def _wait_for_service_status(
|
|
690
|
+
self,
|
|
691
|
+
service_name: sql_identifier.SqlIdentifier,
|
|
692
|
+
target_status: service_sql.ServiceStatus,
|
|
693
|
+
service_database_name: Optional[sql_identifier.SqlIdentifier],
|
|
694
|
+
service_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
695
|
+
async_job: snowpark.AsyncJob,
|
|
696
|
+
statement_params: Optional[dict[str, Any]] = None,
|
|
697
|
+
timeout_minutes: int = 30,
|
|
698
|
+
) -> None:
|
|
699
|
+
"""Wait for service to reach the specified status while monitoring async job for failures.
|
|
700
|
+
|
|
701
|
+
Args:
|
|
702
|
+
service_name: The service to monitor
|
|
703
|
+
target_status: The target status to wait for
|
|
704
|
+
service_database_name: Database containing the service
|
|
705
|
+
service_schema_name: Schema containing the service
|
|
706
|
+
async_job: The async job to monitor for completion/failure
|
|
707
|
+
statement_params: SQL statement parameters
|
|
708
|
+
timeout_minutes: Maximum time to wait before timing out
|
|
709
|
+
|
|
710
|
+
Raises:
|
|
711
|
+
RuntimeError: If service fails, times out, or enters an error state
|
|
712
|
+
"""
|
|
713
|
+
start_time = time.time()
|
|
714
|
+
timeout_seconds = timeout_minutes * 60
|
|
715
|
+
service_seen_before = False
|
|
716
|
+
|
|
717
|
+
while True:
|
|
718
|
+
# Check if async job has failed (but don't return on success - we need specific service status)
|
|
719
|
+
if async_job.is_done():
|
|
720
|
+
try:
|
|
721
|
+
async_job.result()
|
|
722
|
+
# Async job completed successfully, but we're waiting for a specific service status
|
|
723
|
+
# This might mean the service completed and was cleaned up
|
|
724
|
+
module_logger.debug(
|
|
725
|
+
f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
|
|
726
|
+
)
|
|
727
|
+
except Exception as e:
|
|
728
|
+
raise RuntimeError(f"Service deployment failed: {e}")
|
|
729
|
+
|
|
730
|
+
try:
|
|
731
|
+
statuses = self._service_client.get_service_container_statuses(
|
|
732
|
+
database_name=service_database_name,
|
|
733
|
+
schema_name=service_schema_name,
|
|
734
|
+
service_name=service_name,
|
|
735
|
+
include_message=True,
|
|
736
|
+
statement_params=statement_params,
|
|
737
|
+
)
|
|
738
|
+
|
|
739
|
+
if statuses:
|
|
740
|
+
service_seen_before = True
|
|
741
|
+
current_status = statuses[0].service_status
|
|
742
|
+
|
|
743
|
+
# Check if we've reached the target status
|
|
744
|
+
if current_status == target_status:
|
|
745
|
+
return
|
|
746
|
+
|
|
747
|
+
# Check for failure states
|
|
748
|
+
if current_status in [service_sql.ServiceStatus.FAILED, service_sql.ServiceStatus.INTERNAL_ERROR]:
|
|
749
|
+
error_msg = f"Service {service_name} failed with status {current_status.value}"
|
|
750
|
+
if statuses[0].message:
|
|
751
|
+
error_msg += f": {statuses[0].message}"
|
|
752
|
+
raise RuntimeError(error_msg)
|
|
753
|
+
|
|
754
|
+
except exceptions.SnowparkSQLException as e:
|
|
755
|
+
# Service might not exist yet - this is expected during initial deployment
|
|
756
|
+
if "does not exist" in str(e) or "002003" in str(e):
|
|
757
|
+
# If we're waiting for DONE status and we've seen the service before,
|
|
758
|
+
# it likely completed and was cleaned up
|
|
759
|
+
if target_status == service_sql.ServiceStatus.DONE and service_seen_before:
|
|
760
|
+
module_logger.debug(
|
|
761
|
+
f"Service {service_name} disappeared after being seen, "
|
|
762
|
+
f"assuming it reached {target_status.value} and was cleaned up"
|
|
763
|
+
)
|
|
764
|
+
return
|
|
765
|
+
|
|
766
|
+
module_logger.debug(f"Service {service_name} not found yet, continuing to wait...")
|
|
767
|
+
else:
|
|
768
|
+
# Re-raise unexpected SQL exceptions
|
|
769
|
+
raise RuntimeError(f"Error checking service status: {e}")
|
|
770
|
+
except Exception as e:
|
|
771
|
+
# Re-raise unexpected exceptions instead of masking them
|
|
772
|
+
raise RuntimeError(f"Unexpected error while waiting for service status: {e}")
|
|
773
|
+
|
|
774
|
+
# Check timeout
|
|
775
|
+
elapsed_time = time.time() - start_time
|
|
776
|
+
if elapsed_time > timeout_seconds:
|
|
777
|
+
raise RuntimeError(
|
|
778
|
+
f"Timeout waiting for service {service_name} to reach status {target_status.value} "
|
|
779
|
+
f"after {timeout_minutes} minutes"
|
|
780
|
+
)
|
|
781
|
+
|
|
782
|
+
time.sleep(2) # Poll every 2 seconds
|
|
783
|
+
|
|
606
784
|
@staticmethod
|
|
607
785
|
def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
|
|
608
786
|
"""Get the service ID through the server-side logic."""
|
|
@@ -660,9 +838,7 @@ class ServiceOperator:
|
|
|
660
838
|
job_name: sql_identifier.SqlIdentifier,
|
|
661
839
|
compute_pool_name: sql_identifier.SqlIdentifier,
|
|
662
840
|
warehouse_name: sql_identifier.SqlIdentifier,
|
|
663
|
-
|
|
664
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
665
|
-
image_repo_name: sql_identifier.SqlIdentifier,
|
|
841
|
+
image_repo: str,
|
|
666
842
|
output_table_database_name: Optional[sql_identifier.SqlIdentifier],
|
|
667
843
|
output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
668
844
|
output_table_name: sql_identifier.SqlIdentifier,
|
|
@@ -683,6 +859,10 @@ class ServiceOperator:
|
|
|
683
859
|
job_database_name = job_database_name or database_name or self._database_name
|
|
684
860
|
job_schema_name = job_schema_name or schema_name or self._schema_name
|
|
685
861
|
|
|
862
|
+
# Parse image repo
|
|
863
|
+
image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
|
|
864
|
+
image_repo
|
|
865
|
+
)
|
|
686
866
|
image_repo_database_name = image_repo_database_name or database_name or self._database_name
|
|
687
867
|
image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
|
|
688
868
|
|
|
@@ -2,7 +2,7 @@ import dataclasses
|
|
|
2
2
|
import enum
|
|
3
3
|
import logging
|
|
4
4
|
import textwrap
|
|
5
|
-
from typing import Any, Optional
|
|
5
|
+
from typing import Any, Optional
|
|
6
6
|
|
|
7
7
|
from snowflake import snowpark
|
|
8
8
|
from snowflake.ml._internal.utils import (
|
|
@@ -69,43 +69,6 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
|
69
69
|
CONTAINER_STATUS = "status"
|
|
70
70
|
MESSAGE = "message"
|
|
71
71
|
|
|
72
|
-
def build_model_container(
|
|
73
|
-
self,
|
|
74
|
-
*,
|
|
75
|
-
database_name: Optional[sql_identifier.SqlIdentifier],
|
|
76
|
-
schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
77
|
-
model_name: sql_identifier.SqlIdentifier,
|
|
78
|
-
version_name: sql_identifier.SqlIdentifier,
|
|
79
|
-
compute_pool_name: sql_identifier.SqlIdentifier,
|
|
80
|
-
image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
|
|
81
|
-
image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
82
|
-
image_repo_name: sql_identifier.SqlIdentifier,
|
|
83
|
-
gpu: Optional[Union[str, int]],
|
|
84
|
-
force_rebuild: bool,
|
|
85
|
-
external_access_integration: sql_identifier.SqlIdentifier,
|
|
86
|
-
statement_params: Optional[dict[str, Any]] = None,
|
|
87
|
-
) -> None:
|
|
88
|
-
actual_image_repo_database = image_repo_database_name or self._database_name
|
|
89
|
-
actual_image_repo_schema = image_repo_schema_name or self._schema_name
|
|
90
|
-
actual_model_database = database_name or self._database_name
|
|
91
|
-
actual_model_schema = schema_name or self._schema_name
|
|
92
|
-
fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name)
|
|
93
|
-
fq_image_repo_name = identifier.get_schema_level_object_identifier(
|
|
94
|
-
actual_image_repo_database.identifier(),
|
|
95
|
-
actual_image_repo_schema.identifier(),
|
|
96
|
-
image_repo_name.identifier(),
|
|
97
|
-
)
|
|
98
|
-
is_gpu_str = "TRUE" if gpu else "FALSE"
|
|
99
|
-
force_rebuild_str = "TRUE" if force_rebuild else "FALSE"
|
|
100
|
-
query_result_checker.SqlResultValidator(
|
|
101
|
-
self._session,
|
|
102
|
-
(
|
|
103
|
-
f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
|
|
104
|
-
f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')"
|
|
105
|
-
),
|
|
106
|
-
statement_params=statement_params,
|
|
107
|
-
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
108
|
-
|
|
109
72
|
def deploy_model(
|
|
110
73
|
self,
|
|
111
74
|
*,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
import warnings
|
|
3
4
|
from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
|
|
@@ -24,6 +25,8 @@ if TYPE_CHECKING:
|
|
|
24
25
|
import sklearn.base
|
|
25
26
|
import sklearn.pipeline
|
|
26
27
|
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
27
30
|
|
|
28
31
|
def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
|
|
29
32
|
new_steps = []
|
|
@@ -201,13 +204,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
|
201
204
|
explain_target_method = str(explain_target_method) # mypy complains if we don't cast to str here
|
|
202
205
|
|
|
203
206
|
input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
|
|
204
|
-
transformed_background_data = _apply_transforms_up_to_last_step(
|
|
205
|
-
model=model,
|
|
206
|
-
data=background_data,
|
|
207
|
-
input_feature_names=[spec.name for spec in input_signature],
|
|
208
|
-
)
|
|
209
207
|
|
|
210
208
|
try:
|
|
209
|
+
transformed_background_data = _apply_transforms_up_to_last_step(
|
|
210
|
+
model=model,
|
|
211
|
+
data=background_data,
|
|
212
|
+
input_feature_names=[spec.name for spec in input_signature],
|
|
213
|
+
)
|
|
211
214
|
model_meta = handlers_utils.add_inferred_explain_method_signature(
|
|
212
215
|
model_meta=model_meta,
|
|
213
216
|
explain_method="explain",
|
|
@@ -217,6 +220,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
|
217
220
|
output_feature_names=transformed_background_data.columns,
|
|
218
221
|
)
|
|
219
222
|
except Exception:
|
|
223
|
+
logger.debug("Explainability is disabled due to an exception.", exc_info=True)
|
|
220
224
|
if kwargs.get("enable_explainability", None):
|
|
221
225
|
# user explicitly enabled explainability, so we should raise the error
|
|
222
226
|
raise ValueError(
|
|
@@ -86,6 +86,9 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
|
86
86
|
df_col_data = utils.series_dropna(df_col_data)
|
|
87
87
|
df_col_dtype = df_col_data.dtype
|
|
88
88
|
|
|
89
|
+
if utils.check_if_series_is_empty(df_col_data):
|
|
90
|
+
continue
|
|
91
|
+
|
|
89
92
|
if df_col_dtype == np.dtype("O"):
|
|
90
93
|
# Check if all objects have the same type
|
|
91
94
|
if not all(isinstance(data_row, type(df_col_data.iloc[0])) for data_row in df_col_data):
|
|
@@ -412,3 +412,7 @@ def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
|
|
|
412
412
|
specs.append(core.FeatureSpec(name=key, dtype=core.DataType.from_numpy_type(np.array(value).dtype)))
|
|
413
413
|
|
|
414
414
|
return core.FeatureGroupSpec(name=name, specs=specs)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
|
|
418
|
+
return series is None or series.empty
|
|
@@ -23,12 +23,24 @@ class _TqdmStatusContext:
|
|
|
23
23
|
if state == "complete":
|
|
24
24
|
self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
|
|
25
25
|
self._progress_bar.set_description(label)
|
|
26
|
+
elif state == "error":
|
|
27
|
+
# For error state, use the label as-is and mark with ERROR prefix
|
|
28
|
+
# Don't update progress bar position for errors - leave it where it was
|
|
29
|
+
self._progress_bar.set_description(f"❌ ERROR: {label}")
|
|
26
30
|
else:
|
|
27
|
-
|
|
31
|
+
combined_desc = f"{self._label}: {label}" if label != self._label else self._label
|
|
32
|
+
self._progress_bar.set_description(combined_desc)
|
|
28
33
|
|
|
29
|
-
def increment(self
|
|
34
|
+
def increment(self) -> None:
|
|
30
35
|
"""Increment the progress bar."""
|
|
31
|
-
self._progress_bar.update(
|
|
36
|
+
self._progress_bar.update(1)
|
|
37
|
+
|
|
38
|
+
def complete(self) -> None:
|
|
39
|
+
"""Complete the progress bar to full state."""
|
|
40
|
+
if self._total:
|
|
41
|
+
remaining = self._total - self._progress_bar.n
|
|
42
|
+
if remaining > 0:
|
|
43
|
+
self._progress_bar.update(remaining)
|
|
32
44
|
|
|
33
45
|
|
|
34
46
|
class _StreamlitStatusContext:
|
|
@@ -39,6 +51,7 @@ class _StreamlitStatusContext:
|
|
|
39
51
|
self._streamlit = streamlit_module
|
|
40
52
|
self._total = total
|
|
41
53
|
self._current = 0
|
|
54
|
+
self._current_label = label
|
|
42
55
|
self._progress_bar = None
|
|
43
56
|
|
|
44
57
|
def __enter__(self) -> "_StreamlitStatusContext":
|
|
@@ -49,26 +62,70 @@ class _StreamlitStatusContext:
|
|
|
49
62
|
return self
|
|
50
63
|
|
|
51
64
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
52
|
-
|
|
65
|
+
# Only update to complete if there was no exception
|
|
66
|
+
if exc_type is None:
|
|
67
|
+
self._status_container.update(state="complete")
|
|
53
68
|
|
|
54
69
|
def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
|
|
55
70
|
"""Update the status label."""
|
|
56
|
-
if state
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
71
|
+
if state == "complete" or state == "error":
|
|
72
|
+
# For completion/error, use the message as-is and update main status
|
|
73
|
+
self._status_container.update(label=label, state=state, expanded=expanded)
|
|
74
|
+
self._current_label = label
|
|
75
|
+
|
|
76
|
+
# For error state, update progress bar text but preserve position
|
|
77
|
+
if state == "error" and self._total is not None and self._progress_bar is not None:
|
|
78
|
+
self._progress_bar.progress(
|
|
79
|
+
self._current / self._total if self._total > 0 else 0,
|
|
80
|
+
text=f"ERROR - ({self._current}/{self._total})",
|
|
81
|
+
)
|
|
82
|
+
else:
|
|
83
|
+
combined_label = f"{self._label}: {label}" if label != self._label else self._label
|
|
84
|
+
self._status_container.update(label=combined_label, state=state, expanded=expanded)
|
|
85
|
+
self._current_label = label
|
|
86
|
+
if self._total is not None and self._progress_bar is not None:
|
|
87
|
+
progress_value = self._current / self._total if self._total > 0 else 0
|
|
88
|
+
self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
|
|
89
|
+
|
|
90
|
+
def increment(self) -> None:
|
|
66
91
|
"""Increment the progress."""
|
|
67
92
|
if self._total is not None:
|
|
68
|
-
self._current = min(self._current +
|
|
93
|
+
self._current = min(self._current + 1, self._total)
|
|
69
94
|
if self._progress_bar is not None:
|
|
70
95
|
progress_value = self._current / self._total if self._total > 0 else 0
|
|
71
|
-
self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
|
|
96
|
+
self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
|
|
97
|
+
|
|
98
|
+
def complete(self) -> None:
|
|
99
|
+
"""Complete the progress bar to full state."""
|
|
100
|
+
if self._total is not None:
|
|
101
|
+
self._current = self._total
|
|
102
|
+
if self._progress_bar is not None:
|
|
103
|
+
self._progress_bar.progress(1.0, text=f"({self._current}/{self._total})")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class _NoOpStatusContext:
|
|
107
|
+
"""A no-op context manager for when status updates should be disabled."""
|
|
108
|
+
|
|
109
|
+
def __init__(self, label: str) -> None:
|
|
110
|
+
self._label = label
|
|
111
|
+
|
|
112
|
+
def __enter__(self) -> "_NoOpStatusContext":
|
|
113
|
+
return self
|
|
114
|
+
|
|
115
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
116
|
+
pass
|
|
117
|
+
|
|
118
|
+
def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
|
|
119
|
+
"""No-op update method."""
|
|
120
|
+
pass
|
|
121
|
+
|
|
122
|
+
def increment(self) -> None:
|
|
123
|
+
"""No-op increment method."""
|
|
124
|
+
pass
|
|
125
|
+
|
|
126
|
+
def complete(self) -> None:
|
|
127
|
+
"""No-op complete method."""
|
|
128
|
+
pass
|
|
72
129
|
|
|
73
130
|
|
|
74
131
|
class ModelEventHandler:
|
|
@@ -99,7 +156,15 @@ class ModelEventHandler:
|
|
|
99
156
|
else:
|
|
100
157
|
self._tqdm.tqdm.write(message)
|
|
101
158
|
|
|
102
|
-
def status(
|
|
159
|
+
def status(
|
|
160
|
+
self,
|
|
161
|
+
label: str,
|
|
162
|
+
*,
|
|
163
|
+
state: str = "running",
|
|
164
|
+
expanded: bool = True,
|
|
165
|
+
total: Optional[int] = None,
|
|
166
|
+
block: bool = True,
|
|
167
|
+
) -> Any:
|
|
103
168
|
"""Context manager that provides status updates with optional enhanced display capabilities.
|
|
104
169
|
|
|
105
170
|
Args:
|
|
@@ -107,10 +172,14 @@ class ModelEventHandler:
|
|
|
107
172
|
state: The initial state ("running", "complete", "error")
|
|
108
173
|
expanded: Whether to show expanded view (streamlit only)
|
|
109
174
|
total: Total number of steps for progress tracking (optional)
|
|
175
|
+
block: Whether to show progress updates (no-op if False)
|
|
110
176
|
|
|
111
177
|
Returns:
|
|
112
|
-
Status context (Streamlit or
|
|
178
|
+
Status context (Streamlit, Tqdm, or NoOp based on availability and block parameter)
|
|
113
179
|
"""
|
|
180
|
+
if not block:
|
|
181
|
+
return _NoOpStatusContext(label)
|
|
182
|
+
|
|
114
183
|
if self._streamlit is not None:
|
|
115
184
|
return _StreamlitStatusContext(label, self._streamlit, total)
|
|
116
185
|
else:
|
|
@@ -272,6 +272,8 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
|
|
|
272
272
|
),
|
|
273
273
|
)
|
|
274
274
|
else:
|
|
275
|
+
if utils.check_if_series_is_empty(data_col):
|
|
276
|
+
continue
|
|
275
277
|
if isinstance(data_col.iloc[0], list):
|
|
276
278
|
if not ft_shape:
|
|
277
279
|
raise snowml_exceptions.SnowflakeMLException(
|