snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. snowflake/ml/_internal/utils/mixins.py +6 -4
  2. snowflake/ml/_internal/utils/service_logger.py +118 -4
  3. snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
  4. snowflake/ml/data/data_connector.py +4 -34
  5. snowflake/ml/dataset/dataset.py +1 -1
  6. snowflake/ml/dataset/dataset_reader.py +2 -8
  7. snowflake/ml/experiment/__init__.py +3 -0
  8. snowflake/ml/experiment/callback/lightgbm.py +55 -0
  9. snowflake/ml/experiment/callback/xgboost.py +63 -0
  10. snowflake/ml/experiment/utils.py +14 -0
  11. snowflake/ml/jobs/_utils/constants.py +15 -4
  12. snowflake/ml/jobs/_utils/payload_utils.py +159 -52
  13. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  14. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
  15. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  16. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  17. snowflake/ml/jobs/_utils/types.py +64 -4
  18. snowflake/ml/jobs/job.py +22 -6
  19. snowflake/ml/jobs/manager.py +5 -3
  20. snowflake/ml/model/_client/model/model_version_impl.py +56 -48
  21. snowflake/ml/model/_client/ops/service_ops.py +194 -14
  22. snowflake/ml/model/_client/sql/service.py +1 -38
  23. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  24. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
  25. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  26. snowflake/ml/model/_signatures/utils.py +4 -0
  27. snowflake/ml/model/event_handler.py +87 -18
  28. snowflake/ml/model/model_signature.py +2 -0
  29. snowflake/ml/model/models/huggingface_pipeline.py +71 -49
  30. snowflake/ml/model/type_hints.py +26 -1
  31. snowflake/ml/registry/_manager/model_manager.py +30 -35
  32. snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
  33. snowflake/ml/registry/registry.py +0 -19
  34. snowflake/ml/version.py +1 -1
  35. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
  36. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
  37. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
  38. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
  39. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/service_ops.py

@@ -96,11 +96,13 @@ class ServiceLogMetadata:
         msg: str,
         is_model_build_service_done: bool,
         is_model_logger_service_done: bool,
+        operation_id: str,
         propagate: bool = False,
     ) -> None:
         to_service_logger = service_logger.get_logger(
             f"{to_service.display_service_name}-{to_service.instance_id}",
             to_service.log_color,
+            operation_id=operation_id,
         )
         to_service_logger.propagate = propagate
         self.service_logger = to_service_logger
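Note: the hunks in this file thread a new operation_id argument into service_logger.get_logger so that each deployment's build, logger, and inference logs are grouped under one operation. The service_logger changes themselves (+118 -4) are not shown in this view; the snippet below is only a hypothetical sketch of the id format suggested by the docstring later in this diff ("model_deploy_a1b2c3d4_1703875200") and of how such an id could namespace a logger. It is not the package's implementation.

import logging
import time
import uuid


def get_operation_id(prefix: str = "model_deploy") -> str:
    # Hypothetical: yields ids shaped like "model_deploy_a1b2c3d4_1703875200".
    return f"{prefix}_{uuid.uuid4().hex[:8]}_{int(time.time())}"


def get_logger(name: str, operation_id: str) -> logging.Logger:
    # Namespacing by operation_id keeps concurrent deployments from
    # interleaving their service logs under the same logger name.
    return logging.getLogger(f"{operation_id}.{name}")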
@@ -178,9 +180,7 @@ class ServiceOperator:
         service_name: sql_identifier.SqlIdentifier,
         image_build_compute_pool_name: sql_identifier.SqlIdentifier,
         service_compute_pool_name: sql_identifier.SqlIdentifier,
-        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_name: sql_identifier.SqlIdentifier,
+        image_repo: str,
         ingress_enabled: bool,
         max_instances: int,
         cpu_requests: Optional[str],
@@ -191,11 +191,15 @@ class ServiceOperator:
         force_rebuild: bool,
         build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
         block: bool,
+        progress_status: type_hints.ProgressStatus,
         statement_params: Optional[dict[str, Any]] = None,
         # hf model
         hf_model_args: Optional[HFModelArgs] = None,
     ) -> Union[str, async_job.AsyncJob]:

+        # Generate operation ID for this deployment
+        operation_id = service_logger.get_operation_id()
+
         # Fall back to the registry's database and schema if not provided
         database_name = database_name or self._database_name
         schema_name = schema_name or self._schema_name
@@ -204,8 +208,17 @@ class ServiceOperator:
         service_database_name = service_database_name or database_name or self._database_name
         service_schema_name = service_schema_name or schema_name or self._schema_name

+        # Parse image repo
+        image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
+            image_repo
+        )
         image_repo_database_name = image_repo_database_name or database_name or self._database_name
         image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
+
+        # Step 1: Preparing deployment artifacts
+        progress_status.update("preparing deployment artifacts...")
+        progress_status.increment()
+
         if self._workspace:
             stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
         else:
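Note: create_service (and the job path later in this file) now takes a single image_repo string instead of three identifier arguments, parses it with sql_identifier.parse_fully_qualified_name, and falls back to the model's database and schema for any parts that are omitted. The stand-alone sketch below only illustrates that fallback behavior; resolve_image_repo is a hypothetical helper (it ignores quoted identifiers) and is not part of the package.

def resolve_image_repo(image_repo: str, default_db: str, default_schema: str) -> tuple[str, str, str]:
    # Hypothetical illustration of the fallback logic in the hunk above;
    # the real parsing is done by sql_identifier.parse_fully_qualified_name.
    parts = image_repo.split(".")
    if len(parts) == 3:  # "DB.SCHEMA.REPO"
        return parts[0], parts[1], parts[2]
    if len(parts) == 2:  # "SCHEMA.REPO" -> inherit the model's database
        return default_db, parts[0], parts[1]
    return default_db, default_schema, parts[0]  # "REPO" -> inherit both


# An unqualified repo name inherits the model's database and schema:
assert resolve_image_repo("MY_REPO", "MY_DB", "MY_SCHEMA") == ("MY_DB", "MY_SCHEMA", "MY_REPO")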
@@ -254,6 +267,11 @@ class ServiceOperator:
                 **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
             )
         spec_yaml_str_or_path = self._model_deployment_spec.save()
+
+        # Step 2: Uploading deployment artifacts
+        progress_status.update("uploading deployment artifacts...")
+        progress_status.increment()
+
         if self._workspace:
             assert stage_path is not None
             file_utils.upload_directory_to_stage(
@@ -276,6 +294,10 @@ class ServiceOperator:
                 statement_params=statement_params,
             )

+        # Step 3: Initiating model deployment
+        progress_status.update("initiating model deployment...")
+        progress_status.increment()
+
         # deploy the model service
         query_id, async_job = self._service_client.deploy_model(
             stage_path=stage_path if self._workspace else None,
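Note: the six "Step N" updates added in these hunks drive a progress_status object typed as type_hints.ProgressStatus (see type_hints.py, +26 -1, not shown here). A minimal stand-in that satisfies the calls made in this file might look like the following; it illustrates the call pattern only and is not the actual Protocol or the tqdm/streamlit contexts defined in event_handler.py.

class PrintingProgressStatus:
    """Toy ProgressStatus stand-in: prints each step instead of rendering a bar."""

    def __init__(self, total: int) -> None:
        self.total = total
        self.step = 0

    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
        print(f"[{state}] {label}")

    def increment(self) -> None:
        self.step = min(self.step + 1, self.total)
        print(f"  progress: {self.step}/{self.total}")


status = PrintingProgressStatus(total=6)
status.update("preparing deployment artifacts...")
status.increment()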
@@ -327,17 +349,68 @@ class ServiceOperator:
             model_inference_service=model_inference_service,
             model_inference_service_exists=model_inference_service_exists,
             force_rebuild=force_rebuild,
+            operation_id=operation_id,
             statement_params=statement_params,
         )

         if block:
-            log_thread.join()
+            try:
+                # Step 4: Starting model build: waits for build to start
+                progress_status.update("starting model image build...")
+                progress_status.increment()
+
+                # Poll for model build to start if not using existing service
+                if not model_inference_service_exists:
+                    self._wait_for_service_status(
+                        model_build_service_name,
+                        service_sql.ServiceStatus.RUNNING,
+                        service_database_name,
+                        service_schema_name,
+                        async_job,
+                        statement_params,
+                    )

-            res = cast(str, cast(list[row.Row], async_job.result())[0][0])
-            module_logger.info(f"Inference service {service_name} deployment complete: {res}")
-            return res
-        else:
-            return async_job
+                # Step 5: Building model image
+                progress_status.update("building model image...")
+                progress_status.increment()
+
+                # Poll for model build completion
+                if not model_inference_service_exists:
+                    self._wait_for_service_status(
+                        model_build_service_name,
+                        service_sql.ServiceStatus.DONE,
+                        service_database_name,
+                        service_schema_name,
+                        async_job,
+                        statement_params,
+                    )
+
+                # Step 6: Deploying model service (push complete, starting inference service)
+                progress_status.update("deploying model service...")
+                progress_status.increment()
+
+                log_thread.join()
+
+                res = cast(str, cast(list[row.Row], async_job.result())[0][0])
+                module_logger.info(f"Inference service {service_name} deployment complete: {res}")
+                return res
+
+            except RuntimeError as e:
+                # Handle service creation/deployment failures
+                error_msg = f"Model service deployment failed: {str(e)}"
+                module_logger.error(error_msg)
+
+                # Update progress status to show failure
+                progress_status.update(error_msg, state="error")
+
+                # Stop the log thread if it's running
+                if "log_thread" in locals() and log_thread.is_alive():
+                    log_thread.join(timeout=5)  # Give it a few seconds to finish gracefully
+
+                # Re-raise the exception to propagate the error
+                raise RuntimeError(error_msg) from e
+
+        return async_job

     def _start_service_log_streaming(
         self,
@@ -347,6 +420,7 @@ class ServiceOperator:
         model_inference_service: ServiceLogInfo,
         model_inference_service_exists: bool,
         force_rebuild: bool,
+        operation_id: str,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> threading.Thread:
         """Start the service log streaming in a separate thread."""
@@ -360,6 +434,7 @@ class ServiceOperator:
                 model_inference_service,
                 model_inference_service_exists,
                 force_rebuild,
+                operation_id,
                 statement_params,
             ),
         )
@@ -372,6 +447,7 @@ class ServiceOperator:
         service_log_meta: ServiceLogMetadata,
         model_build_service: ServiceLogInfo,
         model_inference_service: ServiceLogInfo,
+        operation_id: str,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
         """Helper function to fetch logs and update the service log metadata if needed.
@@ -386,6 +462,7 @@ class ServiceOperator:
            service_log_meta: The ServiceLogMetadata holds the state of the service log metadata.
            model_build_service: The ServiceLogInfo for the model build service.
            model_inference_service: The ServiceLogInfo for the model inference service.
+           operation_id: The operation ID for the service, e.g. "model_deploy_a1b2c3d4_1703875200"
            statement_params: The statement parameters to use for the service client.
        """

@@ -415,6 +492,7 @@ class ServiceOperator:
                 "Model build is not rebuilding the inference image, but using a previously built image.",
                 is_model_build_service_done=True,
                 is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
+                operation_id=operation_id,
             )

         try:
@@ -488,6 +566,7 @@ class ServiceOperator:
                     f"Model Logger service {service.display_service_name} complete.",
                     is_model_build_service_done=False,
                     is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
+                    operation_id=operation_id,
                 )
             # check if model build service is done
             # and transition the service log metadata to the model inference service
@@ -497,6 +576,7 @@ class ServiceOperator:
                     f"Image build service {service.display_service_name} complete.",
                     is_model_build_service_done=True,
                     is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
+                    operation_id=operation_id,
                 )
             else:
                 module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
@@ -509,6 +589,7 @@ class ServiceOperator:
         model_inference_service: ServiceLogInfo,
         model_inference_service_exists: bool,
         force_rebuild: bool,
+        operation_id: str,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
         """Stream service logs while the async job is running."""
@@ -516,14 +597,14 @@ class ServiceOperator:
         model_build_service_logger = service_logger.get_logger(  # BuildJobName
             model_build_service.display_service_name,
             model_build_service.log_color,
+            operation_id=operation_id,
         )
-        model_build_service_logger.propagate = False
         if model_logger_service:
             model_logger_service_logger = service_logger.get_logger(  # ModelLoggerName
                 model_logger_service.display_service_name,
                 model_logger_service.log_color,
+                operation_id=operation_id,
             )
-            model_logger_service_logger.propagate = False

         service_log_meta = ServiceLogMetadata(
             service_logger=model_logger_service_logger,
@@ -557,6 +638,7 @@ class ServiceOperator:
                     force_rebuild=force_rebuild,
                     model_build_service=model_build_service,
                     model_inference_service=model_inference_service,
+                    operation_id=operation_id,
                     statement_params=statement_params,
                 )
             except Exception as ex:
@@ -564,6 +646,7 @@ class ServiceOperator:
                 is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
                 contains_msg = any(msg in str(ex) for msg in ["Pending scheduling", "Waiting to start"])
                 matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
+
                 if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
                     module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
             time.sleep(5)
@@ -603,6 +686,101 @@ class ServiceOperator:
         except Exception as ex:
             module_logger.warning(f"Caught an exception when logging: {repr(ex)}")

+    def _wait_for_service_status(
+        self,
+        service_name: sql_identifier.SqlIdentifier,
+        target_status: service_sql.ServiceStatus,
+        service_database_name: Optional[sql_identifier.SqlIdentifier],
+        service_schema_name: Optional[sql_identifier.SqlIdentifier],
+        async_job: snowpark.AsyncJob,
+        statement_params: Optional[dict[str, Any]] = None,
+        timeout_minutes: int = 30,
+    ) -> None:
+        """Wait for service to reach the specified status while monitoring async job for failures.
+
+        Args:
+            service_name: The service to monitor
+            target_status: The target status to wait for
+            service_database_name: Database containing the service
+            service_schema_name: Schema containing the service
+            async_job: The async job to monitor for completion/failure
+            statement_params: SQL statement parameters
+            timeout_minutes: Maximum time to wait before timing out
+
+        Raises:
+            RuntimeError: If service fails, times out, or enters an error state
+        """
+        start_time = time.time()
+        timeout_seconds = timeout_minutes * 60
+        service_seen_before = False
+
+        while True:
+            # Check if async job has failed (but don't return on success - we need specific service status)
+            if async_job.is_done():
+                try:
+                    async_job.result()
+                    # Async job completed successfully, but we're waiting for a specific service status
+                    # This might mean the service completed and was cleaned up
+                    module_logger.debug(
+                        f"Async job completed but we're still waiting for {service_name} to reach {target_status.value}"
+                    )
+                except Exception as e:
+                    raise RuntimeError(f"Service deployment failed: {e}")
+
+            try:
+                statuses = self._service_client.get_service_container_statuses(
+                    database_name=service_database_name,
+                    schema_name=service_schema_name,
+                    service_name=service_name,
+                    include_message=True,
+                    statement_params=statement_params,
+                )
+
+                if statuses:
+                    service_seen_before = True
+                    current_status = statuses[0].service_status
+
+                    # Check if we've reached the target status
+                    if current_status == target_status:
+                        return
+
+                    # Check for failure states
+                    if current_status in [service_sql.ServiceStatus.FAILED, service_sql.ServiceStatus.INTERNAL_ERROR]:
+                        error_msg = f"Service {service_name} failed with status {current_status.value}"
+                        if statuses[0].message:
+                            error_msg += f": {statuses[0].message}"
+                        raise RuntimeError(error_msg)
+
+            except exceptions.SnowparkSQLException as e:
+                # Service might not exist yet - this is expected during initial deployment
+                if "does not exist" in str(e) or "002003" in str(e):
+                    # If we're waiting for DONE status and we've seen the service before,
+                    # it likely completed and was cleaned up
+                    if target_status == service_sql.ServiceStatus.DONE and service_seen_before:
+                        module_logger.debug(
+                            f"Service {service_name} disappeared after being seen, "
+                            f"assuming it reached {target_status.value} and was cleaned up"
+                        )
+                        return
+
+                    module_logger.debug(f"Service {service_name} not found yet, continuing to wait...")
+                else:
+                    # Re-raise unexpected SQL exceptions
+                    raise RuntimeError(f"Error checking service status: {e}")
+            except Exception as e:
+                # Re-raise unexpected exceptions instead of masking them
+                raise RuntimeError(f"Unexpected error while waiting for service status: {e}")

+            # Check timeout
+            elapsed_time = time.time() - start_time
+            if elapsed_time > timeout_seconds:
+                raise RuntimeError(
+                    f"Timeout waiting for service {service_name} to reach status {target_status.value} "
+                    f"after {timeout_minutes} minutes"
+                )
+
+            time.sleep(2)  # Poll every 2 seconds
+
     @staticmethod
     def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
         """Get the service ID through the server-side logic."""
@@ -660,9 +838,7 @@ class ServiceOperator:
         job_name: sql_identifier.SqlIdentifier,
         compute_pool_name: sql_identifier.SqlIdentifier,
         warehouse_name: sql_identifier.SqlIdentifier,
-        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_name: sql_identifier.SqlIdentifier,
+        image_repo: str,
         output_table_database_name: Optional[sql_identifier.SqlIdentifier],
         output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
         output_table_name: sql_identifier.SqlIdentifier,
@@ -683,6 +859,10 @@ class ServiceOperator:
         job_database_name = job_database_name or database_name or self._database_name
         job_schema_name = job_schema_name or schema_name or self._schema_name

+        # Parse image repo
+        image_repo_database_name, image_repo_schema_name, image_repo_name = sql_identifier.parse_fully_qualified_name(
+            image_repo
+        )
         image_repo_database_name = image_repo_database_name or database_name or self._database_name
         image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name

snowflake/ml/model/_client/sql/service.py

@@ -2,7 +2,7 @@ import dataclasses
 import enum
 import logging
 import textwrap
-from typing import Any, Optional, Union
+from typing import Any, Optional

 from snowflake import snowpark
 from snowflake.ml._internal.utils import (
@@ -69,43 +69,6 @@ class ServiceSQLClient(_base._BaseSQLClient):
     CONTAINER_STATUS = "status"
     MESSAGE = "message"

-    def build_model_container(
-        self,
-        *,
-        database_name: Optional[sql_identifier.SqlIdentifier],
-        schema_name: Optional[sql_identifier.SqlIdentifier],
-        model_name: sql_identifier.SqlIdentifier,
-        version_name: sql_identifier.SqlIdentifier,
-        compute_pool_name: sql_identifier.SqlIdentifier,
-        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
-        image_repo_name: sql_identifier.SqlIdentifier,
-        gpu: Optional[Union[str, int]],
-        force_rebuild: bool,
-        external_access_integration: sql_identifier.SqlIdentifier,
-        statement_params: Optional[dict[str, Any]] = None,
-    ) -> None:
-        actual_image_repo_database = image_repo_database_name or self._database_name
-        actual_image_repo_schema = image_repo_schema_name or self._schema_name
-        actual_model_database = database_name or self._database_name
-        actual_model_schema = schema_name or self._schema_name
-        fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name)
-        fq_image_repo_name = identifier.get_schema_level_object_identifier(
-            actual_image_repo_database.identifier(),
-            actual_image_repo_schema.identifier(),
-            image_repo_name.identifier(),
-        )
-        is_gpu_str = "TRUE" if gpu else "FALSE"
-        force_rebuild_str = "TRUE" if force_rebuild else "FALSE"
-        query_result_checker.SqlResultValidator(
-            self._session,
-            (
-                f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
-                f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')"
-            ),
-            statement_params=statement_params,
-        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
-
     def deploy_model(
         self,
         *,
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -1,3 +1,4 @@
+import logging
 import os
 import warnings
 from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
@@ -24,6 +25,8 @@ if TYPE_CHECKING:
     import sklearn.base
     import sklearn.pipeline

+logger = logging.getLogger(__name__)
+


 def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "sklearn.pipeline.Pipeline":
@@ -201,13 +204,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             explain_target_method = str(explain_target_method)  # mypy complains if we don't cast to str here

             input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
-            transformed_background_data = _apply_transforms_up_to_last_step(
-                model=model,
-                data=background_data,
-                input_feature_names=[spec.name for spec in input_signature],
-            )

             try:
+                transformed_background_data = _apply_transforms_up_to_last_step(
+                    model=model,
+                    data=background_data,
+                    input_feature_names=[spec.name for spec in input_signature],
+                )
                 model_meta = handlers_utils.add_inferred_explain_method_signature(
                     model_meta=model_meta,
                     explain_method="explain",
@@ -217,6 +220,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     output_feature_names=transformed_background_data.columns,
                 )
             except Exception:
+                logger.debug("Explainability is disabled due to an exception.", exc_info=True)
                 if kwargs.get("enable_explainability", None):
                     # user explicitly enabled explainability, so we should raise the error
                     raise ValueError(
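Note: moving _apply_transforms_up_to_last_step inside the try block means a failure while preparing background data now disables explainability quietly (logged at DEBUG) unless the user explicitly requested it, in which case it still raises. The self-contained sketch below distills that control flow; _prepare_background_data is a hypothetical stand-in for the handler's transform step, not the package's code.

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def _prepare_background_data() -> None:
    # Stand-in for _apply_transforms_up_to_last_step; assume it can raise
    # for pipelines whose intermediate steps reject the sample data.
    raise NotImplementedError("transform not supported for this pipeline")


def maybe_enable_explainability(explicitly_requested: bool) -> bool:
    try:
        _prepare_background_data()
        return True
    except Exception:
        logger.debug("Explainability is disabled due to an exception.", exc_info=True)
        if explicitly_requested:
            # mirrors the handler: an explicit opt-in surfaces the failure
            raise ValueError("Explainability could not be enabled for this model.")
        return False


print(maybe_enable_explainability(explicitly_requested=False))  # False, failure only logged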
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -13,6 +13,7 @@ REQUIREMENTS = [
     "numpy>=1.23,<3",
     "packaging>=20.9,<25",
     "pandas>=2.1.4,<3",
+    "platformdirs<5",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
snowflake/ml/model/_signatures/pandas_handler.py

@@ -86,6 +86,9 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
             df_col_data = utils.series_dropna(df_col_data)
             df_col_dtype = df_col_data.dtype

+            if utils.check_if_series_is_empty(df_col_data):
+                continue
+
             if df_col_dtype == np.dtype("O"):
                 # Check if all objects have the same type
                 if not all(isinstance(data_row, type(df_col_data.iloc[0])) for data_row in df_col_data):
snowflake/ml/model/_signatures/utils.py

@@ -412,3 +412,7 @@ def infer_dict(name: str, data: dict[str, Any]) -> core.FeatureGroupSpec:
         specs.append(core.FeatureSpec(name=key, dtype=core.DataType.from_numpy_type(np.array(value).dtype)))

     return core.FeatureGroupSpec(name=name, specs=specs)
+
+
+def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
+    return series is None or series.empty
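Note: this helper backs the new guards in pandas_handler.py and model_signature.py, so columns that are empty after NULLs are dropped never reach the iloc[0]-based type checks. A quick, self-contained illustration (the helper is copied here; dropna() stands in for utils.series_dropna):

from typing import Optional

import pandas as pd


def check_if_series_is_empty(series: Optional[pd.Series]) -> bool:
    return series is None or series.empty


df = pd.DataFrame({"a": [1, 2, 3], "b": [None, None, None]})
print(check_if_series_is_empty(df["b"].dropna()))  # True  -> column is skipped
print(check_if_series_is_empty(df["a"].dropna()))  # False -> inference proceeds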
snowflake/ml/model/event_handler.py

@@ -23,12 +23,24 @@ class _TqdmStatusContext:
         if state == "complete":
             self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
             self._progress_bar.set_description(label)
+        elif state == "error":
+            # For error state, use the label as-is and mark with ERROR prefix
+            # Don't update progress bar position for errors - leave it where it was
+            self._progress_bar.set_description(f"❌ ERROR: {label}")
         else:
-            self._progress_bar.set_description(f"{self._label}: {label}")
+            combined_desc = f"{self._label}: {label}" if label != self._label else self._label
+            self._progress_bar.set_description(combined_desc)

-    def increment(self, n: int = 1) -> None:
+    def increment(self) -> None:
         """Increment the progress bar."""
-        self._progress_bar.update(n)
+        self._progress_bar.update(1)
+
+    def complete(self) -> None:
+        """Complete the progress bar to full state."""
+        if self._total:
+            remaining = self._total - self._progress_bar.n
+            if remaining > 0:
+                self._progress_bar.update(remaining)


 class _StreamlitStatusContext:
@@ -39,6 +51,7 @@ class _StreamlitStatusContext:
         self._streamlit = streamlit_module
         self._total = total
         self._current = 0
+        self._current_label = label
         self._progress_bar = None

     def __enter__(self) -> "_StreamlitStatusContext":
@@ -49,26 +62,70 @@
         return self

     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
-        self._status_container.update(state="complete")
+        # Only update to complete if there was no exception
+        if exc_type is None:
+            self._status_container.update(state="complete")

     def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
         """Update the status label."""
-        if state != "complete":
-            label = f"{self._label}: {label}"
-        self._status_container.update(label=label, state=state, expanded=expanded)
-        if self._progress_bar is not None:
-            self._progress_bar.progress(
-                self._current / self._total if self._total > 0 else 0,
-                text=f"{label} - {self._current}/{self._total}",
-            )
-
-    def increment(self, n: int = 1) -> None:
+        if state == "complete" or state == "error":
+            # For completion/error, use the message as-is and update main status
+            self._status_container.update(label=label, state=state, expanded=expanded)
+            self._current_label = label
+
+            # For error state, update progress bar text but preserve position
+            if state == "error" and self._total is not None and self._progress_bar is not None:
+                self._progress_bar.progress(
+                    self._current / self._total if self._total > 0 else 0,
+                    text=f"ERROR - ({self._current}/{self._total})",
+                )
+        else:
+            combined_label = f"{self._label}: {label}" if label != self._label else self._label
+            self._status_container.update(label=combined_label, state=state, expanded=expanded)
+            self._current_label = label
+            if self._total is not None and self._progress_bar is not None:
+                progress_value = self._current / self._total if self._total > 0 else 0
+                self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
+
+    def increment(self) -> None:
         """Increment the progress."""
         if self._total is not None:
-            self._current = min(self._current + n, self._total)
+            self._current = min(self._current + 1, self._total)
             if self._progress_bar is not None:
                 progress_value = self._current / self._total if self._total > 0 else 0
-                self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
+                self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
+
+    def complete(self) -> None:
+        """Complete the progress bar to full state."""
+        if self._total is not None:
+            self._current = self._total
+            if self._progress_bar is not None:
+                self._progress_bar.progress(1.0, text=f"({self._current}/{self._total})")
+
+
+class _NoOpStatusContext:
+    """A no-op context manager for when status updates should be disabled."""
+
+    def __init__(self, label: str) -> None:
+        self._label = label
+
+    def __enter__(self) -> "_NoOpStatusContext":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        pass
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """No-op update method."""
+        pass
+
+    def increment(self) -> None:
+        """No-op increment method."""
+        pass
+
+    def complete(self) -> None:
+        """No-op complete method."""
+        pass


 class ModelEventHandler:
@@ -99,7 +156,15 @@ class ModelEventHandler:
         else:
             self._tqdm.tqdm.write(message)

-    def status(self, label: str, *, state: str = "running", expanded: bool = True, total: Optional[int] = None) -> Any:
+    def status(
+        self,
+        label: str,
+        *,
+        state: str = "running",
+        expanded: bool = True,
+        total: Optional[int] = None,
+        block: bool = True,
+    ) -> Any:
         """Context manager that provides status updates with optional enhanced display capabilities.

         Args:
@@ -107,10 +172,14 @@
             state: The initial state ("running", "complete", "error")
             expanded: Whether to show expanded view (streamlit only)
             total: Total number of steps for progress tracking (optional)
+            block: Whether to show progress updates (no-op if False)

         Returns:
-            Status context (Streamlit or Tqdm)
+            Status context (Streamlit, Tqdm, or NoOp based on availability and block parameter)
         """
+        if not block:
+            return _NoOpStatusContext(label)
+
         if self._streamlit is not None:
             return _StreamlitStatusContext(label, self._streamlit, total)
         else:
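Note: status() gains a block flag; when it is False the caller gets the new _NoOpStatusContext, so non-blocking deployments render no progress UI at all. A hedged usage sketch follows, assuming ModelEventHandler() is constructed with no arguments as in earlier releases; the calls inside the with block match the context interface shown in these hunks.

from snowflake.ml.model.event_handler import ModelEventHandler

handler = ModelEventHandler()

# block=False -> _NoOpStatusContext: every call below is a silent no-op.
with handler.status("Deploying model", total=6, block=False) as status:
    status.update("preparing deployment artifacts...")
    status.increment()
    status.complete()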
snowflake/ml/model/model_signature.py

@@ -272,6 +272,8 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
         else:
+            if utils.check_if_series_is_empty(data_col):
+                continue
             if isinstance(data_col.iloc[0], list):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(