snowflake-ml-python 1.9.2__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,15 @@ from typing import Optional
 
 import platformdirs
 
+# Module-level logger for operational messages that should appear on console
+stdout_handler = logging.StreamHandler(sys.stdout)
+stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+
+console_logger = logging.getLogger(__name__)
+console_logger.addHandler(stdout_handler)
+console_logger.setLevel(logging.INFO)
+console_logger.propagate = False
+
 
 class LogColor(enum.Enum):
     GREY = "\x1b[38;20m"
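Note: the new module-level `console_logger` is the standard dedicated-console-logger pattern — its own stdout handler plus `propagate = False` keeps these operational messages visible regardless of how the root logger is configured, and keeps them out of the root handlers. A minimal standalone sketch of the pattern, with illustrative names (not from the package):

    import logging
    import sys

    # Dedicated console logger: emits to stdout even if the root logger is silenced.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(logging.Formatter("%(message)s"))

    console_logger = logging.getLogger("demo.console")  # illustrative name
    console_logger.addHandler(stdout_handler)
    console_logger.setLevel(logging.INFO)
    console_logger.propagate = False  # keep messages out of root handlers

    logging.getLogger().setLevel(logging.ERROR)  # root logger silenced...
    console_logger.info("still visible on stdout")  # ...but this still prints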
@@ -109,42 +118,36 @@ def _get_or_create_parent_logger(operation_id: str) -> logging.Logger:
     """Get or create a parent logger with FileHandler for the operation."""
     parent_logger_name = f"snowflake_ml_operation_{operation_id}"
     parent_logger = logging.getLogger(parent_logger_name)
+    parent_logger.setLevel(logging.DEBUG)
+    parent_logger.propagate = False
 
-    # Only add handler if it doesn't exist yet
     if not parent_logger.handlers:
         log_file_path = _get_log_file_path(operation_id)
 
         if log_file_path:
-            # Successfully found a writable location
            try:
                 file_handler = logging.FileHandler(log_file_path)
                 file_handler.setFormatter(logging.Formatter("%(name)s [%(asctime)s] [%(levelname)s] %(message)s"))
                 parent_logger.addHandler(file_handler)
-                parent_logger.setLevel(logging.DEBUG)
-                parent_logger.propagate = False  # Don't propagate to root logger
 
-                # Log the file location
-                parent_logger.warning(f"Operation logs saved to: {log_file_path}")
+                console_logger.info(f"create_service logs saved to: {log_file_path}")
             except OSError as e:
-                # Even though we found a path, file creation failed
-                # Fall back to console-only logging
-                parent_logger.setLevel(logging.DEBUG)
-                parent_logger.propagate = False
-                parent_logger.warning(f"Could not create log file at {log_file_path}: {e}. Using console-only logging.")
+                console_logger.warning(f"Could not create log file at {log_file_path}: {e}.")
         else:
             # No writable location found, use console-only logging
-            parent_logger.setLevel(logging.DEBUG)
-            parent_logger.propagate = False
-            parent_logger.warning("Filesystem appears to be readonly. Using console-only logging.")
+            console_logger.warning("No writable location found for create_service log file.")
+
+        if logging.getLogger().level > logging.INFO:
+            console_logger.info(
+                "To see logs in console, set log level to INFO: logging.getLogger().setLevel(logging.INFO)"
+            )
 
     return parent_logger
 
 
 def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[str] = None) -> logging.Logger:
     logger = logging.getLogger(logger_name)
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(CustomFormatter(info_color))
-    logger.addHandler(handler)
+    root_logger = logging.getLogger()
 
     # If operation_id provided, set up parent logger with file handler
     if operation_id:
@@ -152,6 +155,17 @@ def get_logger(logger_name: str, info_color: LogColor, operation_id: Optional[st
         logger.parent = parent_logger
         logger.propagate = True
 
+        if root_logger.level <= logging.INFO:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+    else:
+        # No operation_id - add console handler only if user wants verbose logging
+        if root_logger.level <= logging.INFO and not logger.handlers:
+            handler = logging.StreamHandler(sys.stdout)
+            handler.setFormatter(CustomFormatter(info_color))
+            logger.addHandler(handler)
+
     return logger
 
 
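Net effect of the two hunks above: detailed logs always go to the per-operation log file, while console handlers are attached only when the root logger is at INFO or below. A caller opts in exactly as the printed hint suggests:

    import logging

    # Opt in to verbose console output before starting the operation;
    # file logging happens regardless of this level.
    logging.getLogger().setLevel(logging.INFO)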
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING, Optional
+from warnings import warn
+
+import lightgbm as lgb
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+        super().__init__(eval_result={})
+
+    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
+        if self.log_params:
+            if env.iteration == env.begin_iteration:  # Log params only at the first iteration
+                self._experiment_tracking.log_params(env.params)
+
+        if self.log_metrics:
+            super().__call__(env)
+            for dataset_name, metrics in self.eval_result.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=env.iteration)
+
+        if self.log_model:
+            if env.iteration == env.end_iteration - 1:  # Log model only at the last iteration
+                if self.model_signature:
+                    model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+                    self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                        model=env.model,
+                        model_name=model_name,
+                        signatures={"predict": self.model_signature},
+                    )
+                else:
+                    warn(
+                        "Model will not be logged because model signature is missing. To autolog the model, "
+                        "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
+                    )
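For context, a plausible end-to-end wiring of the new callback. The callback's import path and the pre-built `exp` tracker are assumptions (the diff does not show file names); `infer_signature` is the existing helper from `snowflake.ml.model.model_signature`:

    import lightgbm as lgb
    from sklearn.datasets import make_regression

    from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback  # assumed module path
    from snowflake.ml.model.model_signature import infer_signature

    # Assumed: `exp` is an already-constructed ExperimentTracking instance; its
    # constructor arguments (session, database, schema) are outside this diff.
    exp = ...

    X, y = make_regression(n_samples=200, n_features=5)
    train_data = lgb.Dataset(X, label=y)

    callback = SnowflakeLightgbmCallback(
        experiment_tracking=exp,
        model_name="LGBM_DEMO_MODEL",  # placeholder name
        model_signature=infer_signature(input_data=X, output_data=y),
    )

    booster = lgb.train(
        params={"objective": "regression"},
        train_set=train_data,
        valid_sets=[train_data],
        valid_names=["train"],
        callbacks=[callback],  # params at first iteration, metrics each iteration, model at last
    )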
@@ -0,0 +1,63 @@
+import json
+from typing import TYPE_CHECKING, Any, Optional
+from warnings import warn
+
+import xgboost as xgb
+
+from snowflake.ml.experiment import utils
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
+    from snowflake.ml.model.model_signature import ModelSignature
+
+
+class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
+    def __init__(
+        self,
+        experiment_tracking: "ExperimentTracking",
+        log_model: bool = True,
+        log_metrics: bool = True,
+        log_params: bool = True,
+        model_name: Optional[str] = None,
+        model_signature: Optional["ModelSignature"] = None,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.log_model = log_model
+        self.log_metrics = log_metrics
+        self.log_params = log_params
+        self.model_name = model_name
+        self.model_signature = model_signature
+
+    def before_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_params:
+            params = json.loads(model.save_config())
+            self._experiment_tracking.log_params(utils.flatten_nested_params(params))
+
+        return model
+
+    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
+        if self.log_metrics:
+            for dataset_name, metrics in evals_log.items():
+                for metric_name, log in metrics.items():
+                    metric_key = dataset_name + ":" + metric_name
+                    self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=epoch)
+
+        return False
+
+    def after_training(self, model: xgb.Booster) -> xgb.Booster:
+        if self.log_model:
+            if not self.model_signature:
+                warn(
+                    "Model will not be logged because model signature is missing. "
+                    "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
+                )
+                return model
+
+            model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
+            self._experiment_tracking.log_model(  # type: ignore[call-arg]
+                model=model,
+                model_name=model_name,
+                signatures={"predict": self.model_signature},
+            )
+
+        return model
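The XGBoost variant spreads the same work across `before_training` (params), `after_iteration` (metrics), and `after_training` (model). Wiring mirrors the LightGBM example above; the import path is again an assumption:

    import xgboost as xgb

    from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback  # assumed module path
    from snowflake.ml.model.model_signature import infer_signature

    dtrain = xgb.DMatrix(X, label=y)  # X, y and `exp` as in the previous example
    booster = xgb.train(
        params={"objective": "reg:squarederror"},
        dtrain=dtrain,
        evals=[(dtrain, "train")],
        callbacks=[
            SnowflakeXgboostCallback(
                experiment_tracking=exp,  # same assumed ExperimentTracking instance
                model_name="XGB_DEMO_MODEL",  # placeholder name
                model_signature=infer_signature(input_data=X, output_data=y),
            )
        ],
    )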
@@ -0,0 +1,14 @@
+from typing import Any, Union
+
+
+def flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
+    flat_params = {}
+    items = params.items() if isinstance(params, dict) else enumerate(params)
+    for key, value in items:
+        key = str(key).replace(".", "_")  # Replace dots in keys to avoid collisions involving nested keys
+        new_prefix = f"{prefix}.{key}" if prefix else key
+        if isinstance(value, (dict, list)):
+            flat_params.update(flatten_nested_params(value, new_prefix))
+        else:
+            flat_params[new_prefix] = value
+    return flat_params
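A quick worked example of this helper: dots inside keys are first rewritten to underscores, then nesting (including list indices) is joined with dots:

    config = {"learner": {"eta": 0.3, "updater": ["grow_histmaker", "prune"]}}
    flatten_nested_params(config)
    # {'learner.eta': 0.3, 'learner.updater.0': 'grow_histmaker', 'learner.updater.1': 'prune'}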
@@ -63,6 +63,13 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
    ##### Set up Python environment #####
    export PYTHONPATH=/opt/env/site-packages/
+    MLRS_SYSTEM_REQUIREMENTS_FILE=${{MLRS_SYSTEM_REQUIREMENTS_FILE:-"${{SYSTEM_DIR}}/requirements.txt"}}
+
+    if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
+        echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
+        pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
+    fi
+
    MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
    if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
        # TODO: Prevent collisions with MLRS packages using virtualenvs
@@ -454,8 +461,6 @@ class JobPayload:
                 overwrite=True,
             )
             source = Path(entrypoint.file_path.parent)
-            if not any(r.startswith("cloudpickle") for r in pip_requirements):
-                pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
 
         elif isinstance(source, stage_utils.StagePath):
             # copy payload to stage
@@ -470,19 +475,20 @@ class JobPayload:
 
         upload_payloads(session, app_stage_path, *additional_payload_specs)
 
-        # Upload requirements to app/ directory
-        # TODO: Check if payload includes both a requirements.txt file and pip_requirements
+        if not any(r.startswith("cloudpickle") for r in pip_requirements):
+            pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
+
+        # Upload system scripts and requirements.txt generated by pip_requirements to system/ directory
+        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         if pip_requirements:
             # Upload requirements.txt to stage
             session.file.put_stream(
                 io.BytesIO("\n".join(pip_requirements).encode()),
-                stage_location=app_stage_path.joinpath("requirements.txt").as_posix(),
+                stage_location=system_stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
             )
 
-        # Upload startup script to system/ directory within payload
-        system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
         # TODO: Make sure payload does not include file with same name
         session.file.put_stream(
             io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
@@ -191,7 +191,7 @@ def wait_for_instances(
             logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
             return
 
-        logger.debug(
+        logger.info(
             f"Waiting for instances: current_instances={total_nodes}, min_instances={min_instances}, "
             f"target_instances={target_instances}, elapsed={elapsed:.1f}s, next check in {current_interval:.1f}s"
         )
@@ -199,7 +199,7 @@
         current_interval = min(current_interval * 2, check_interval)  # Exponential backoff
 
     raise TimeoutError(
-        f"Timed out after {timeout}s waiting for {min_instances} instances, only " f"{total_nodes} available"
+        f"Timed out after {elapsed}s waiting for {min_instances} instances, only " f"{total_nodes} available"
     )
 
 
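For reference, the surrounding function polls with capped exponential backoff, which is why the corrected message reports the actual `elapsed` time rather than the nominal `timeout`. A minimal standalone sketch of that polling pattern (not the package's code):

    import time
    from typing import Callable

    def wait_until(predicate: Callable[[], bool], timeout: float = 300.0, check_interval: float = 30.0) -> None:
        """Poll `predicate` with capped exponential backoff until it returns True."""
        start = time.monotonic()
        interval = 1.0
        while (elapsed := time.monotonic() - start) < timeout:
            if predicate():
                return
            time.sleep(max(0.0, min(interval, timeout - elapsed)))
            interval = min(interval * 2, check_interval)  # exponential backoff, capped
        raise TimeoutError(f"Timed out after {elapsed:.1f}s")  # report actual elapsed time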
@@ -1,5 +1,4 @@
 import enum
-import logging
 import pathlib
 import tempfile
 import warnings
@@ -881,6 +880,7 @@ class ModelVersion(lineage_node.LineageNode):
 
         Raises:
             ValueError: Illegal external access integration arguments.
+            exceptions.SnowparkSQLException: if service already exists.
 
         Returns:
             If `block=True`, return result information about service creation from server.
@@ -891,16 +891,6 @@
             subproject=_TELEMETRY_SUBPROJECT,
         )
 
-        # Check root logger level and emit warning if needed
-        root_logger = logging.getLogger()
-        if root_logger.level in (logging.WARNING, logging.ERROR):
-            warnings.warn(
-                "Suppressing service logs. Set the log level to INFO if you would like "
-                "verbose service logs (e.g., logging.getLogger().setLevel(logging.INFO)).",
-                UserWarning,
-                stacklevel=2,
-            )
-
         if build_external_access_integration is not None:
             msg = (
                 "`build_external_access_integration` is deprecated. "
@@ -917,39 +907,60 @@
 
         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
         image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
-        return self._service_ops.create_service(
-            database_name=None,
-            schema_name=None,
-            model_name=self._model_name,
-            version_name=self._version_name,
-            service_database_name=service_db_id,
-            service_schema_name=service_schema_id,
-            service_name=service_id,
-            image_build_compute_pool_name=(
-                sql_identifier.SqlIdentifier(image_build_compute_pool)
-                if image_build_compute_pool
-                else sql_identifier.SqlIdentifier(service_compute_pool)
-            ),
-            service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-            image_repo_database_name=image_repo_db_id,
-            image_repo_schema_name=image_repo_schema_id,
-            image_repo_name=image_repo_id,
-            ingress_enabled=ingress_enabled,
-            max_instances=max_instances,
-            cpu_requests=cpu_requests,
-            memory_requests=memory_requests,
-            gpu_requests=gpu_requests,
-            num_workers=num_workers,
-            max_batch_rows=max_batch_rows,
-            force_rebuild=force_rebuild,
-            build_external_access_integrations=(
-                None
-                if build_external_access_integrations is None
-                else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
-            ),
-            block=block,
-            statement_params=statement_params,
-        )
+
+        from snowflake.ml.model import event_handler
+        from snowflake.snowpark import exceptions
+
+        model_event_handler = event_handler.ModelEventHandler()
+
+        with model_event_handler.status("Creating model inference service", total=6, block=block) as status:
+            try:
+                result = self._service_ops.create_service(
+                    database_name=None,
+                    schema_name=None,
+                    model_name=self._model_name,
+                    version_name=self._version_name,
+                    service_database_name=service_db_id,
+                    service_schema_name=service_schema_id,
+                    service_name=service_id,
+                    image_build_compute_pool_name=(
+                        sql_identifier.SqlIdentifier(image_build_compute_pool)
+                        if image_build_compute_pool
+                        else sql_identifier.SqlIdentifier(service_compute_pool)
+                    ),
+                    service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+                    image_repo=image_repo,
+                    ingress_enabled=ingress_enabled,
+                    max_instances=max_instances,
+                    cpu_requests=cpu_requests,
+                    memory_requests=memory_requests,
+                    gpu_requests=gpu_requests,
+                    num_workers=num_workers,
+                    max_batch_rows=max_batch_rows,
+                    force_rebuild=force_rebuild,
+                    build_external_access_integrations=(
+                        None
+                        if build_external_access_integrations is None
+                        else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+                    ),
+                    block=block,
+                    statement_params=statement_params,
+                    progress_status=status,
+                )
+                status.update(label="Model service created successfully", state="complete", expanded=False)
+                return result
+            except exceptions.SnowparkSQLException as e:
+                # Check if the error is because the service already exists
+                if "already exists" in str(e).lower() or "100132" in str(
+                    e
+                ):  # 100132 is Snowflake error code for object already exists
+                    status.update("service already exists")
+                    status.complete()
+                    status.update(label="Service already exists", state="error", expanded=False)
+                    raise
+                else:
+                    status.update(label="Service creation failed", state="error", expanded=False)
+                    raise
 
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
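From the caller's side the `create_service` signature is unchanged; the call now renders progress via the event handler, and the already-exists case is surfaced explicitly as a `SnowparkSQLException` (now documented in the Raises section above). A hedged usage sketch with placeholder identifiers (`mv` is a `ModelVersion`):

    from snowflake.snowpark import exceptions

    try:
        result = mv.create_service(
            service_name="MY_DB.MY_SCHEMA.MY_SERVICE",  # placeholder identifiers
            service_compute_pool="MY_COMPUTE_POOL",
            image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",
            ingress_enabled=True,
            max_instances=1,
            block=True,
        )
    except exceptions.SnowparkSQLException:
        # Raised when the service already exists, among other SQL errors.
        raise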
@@ -1045,7 +1056,6 @@
         )
         target_function_info = self._get_function_info(function_name=function_name)
         job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
-        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
         output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
             output_table_name
         )
@@ -1064,9 +1074,7 @@
             job_name=job_id,
             compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
             warehouse_name=sql_identifier.SqlIdentifier(warehouse),
-            image_repo_database_name=image_repo_db_id,
-            image_repo_schema_name=image_repo_schema_id,
-            image_repo_name=image_repo_id,
+            image_repo=image_repo,
             output_table_database_name=output_table_db_id,
             output_table_schema_name=output_table_schema_id,
             output_table_name=output_table_id,