snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
@@ -955,7 +955,7 @@ class ModelOperator:
955
955
  output_with_input_features = False
956
956
  df = model_signature._convert_and_validate_local_data(X, signature.inputs, strict=strict_input_validation)
957
957
  s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
958
- self._session, df, keep_order=keep_order, features=signature.inputs
958
+ self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
959
959
  )
960
960
  else:
961
961
  keep_order = False
@@ -969,9 +969,16 @@ class ModelOperator:
969
969
 
970
970
  # Compose input and output names
971
971
  input_args = []
972
+ quoted_identifiers_ignore_case = (
973
+ snowpark_handler.SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
974
+ self._session, statement_params
975
+ )
976
+ )
977
+
972
978
  for input_feature in signature.inputs:
973
979
  col_name = identifier_rule.get_sql_identifier_from_feature(input_feature.name)
974
-
980
+ if quoted_identifiers_ignore_case:
981
+ col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
975
982
  input_args.append(col_name)
976
983
 
977
984
  returns = []
@@ -1051,7 +1058,9 @@ class ModelOperator:
1051
1058
 
1052
1059
  # Get final result
1053
1060
  if not isinstance(X, dataframe.DataFrame):
1054
- return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
1061
+ return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
1062
+ df_res, features=signature.outputs, statement_params=statement_params
1063
+ )
1055
1064
  else:
1056
1065
  return df_res
1057
1066
 
@@ -1,4 +1,5 @@
1
1
  import dataclasses
2
+ import enum
2
3
  import hashlib
3
4
  import logging
4
5
  import pathlib
@@ -22,20 +23,63 @@ module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY
22
23
  module_logger.propagate = False
23
24
 
24
25
 
26
+ class DeploymentStep(enum.Enum):
27
+ MODEL_BUILD = ("model-build", "model_build_")
28
+ MODEL_INFERENCE = ("model-inference", None)
29
+ MODEL_LOGGING = ("model-logging", "model_logging_")
30
+
31
+ def __init__(self, container_name: str, service_name_prefix: Optional[str]) -> None:
32
+ self._container_name = container_name
33
+ self._service_name_prefix = service_name_prefix
34
+
35
+ @property
36
+ def container_name(self) -> str:
37
+ """Get the container name for the deployment step."""
38
+ return self._container_name
39
+
40
+ @property
41
+ def service_name_prefix(self) -> Optional[str]:
42
+ """Get the service name prefix for the deployment step."""
43
+ return self._service_name_prefix
44
+
45
+
25
46
  @dataclasses.dataclass
26
47
  class ServiceLogInfo:
27
48
  database_name: Optional[sql_identifier.SqlIdentifier]
28
49
  schema_name: Optional[sql_identifier.SqlIdentifier]
29
50
  service_name: sql_identifier.SqlIdentifier
30
- container_name: str
51
+ deployment_step: DeploymentStep
31
52
  instance_id: str = "0"
53
+ log_color: service_logger.LogColor = service_logger.LogColor.GREY
32
54
 
33
55
  def __post_init__(self) -> None:
34
56
  # service name used in logs for display
35
57
  self.display_service_name = sql_identifier.get_fully_qualified_name(
36
- self.database_name, self.schema_name, self.service_name
58
+ self.database_name,
59
+ self.schema_name,
60
+ self.service_name,
37
61
  )
38
62
 
63
+ def fetch_logs(
64
+ self,
65
+ service_client: service_sql.ServiceSQLClient,
66
+ offset: int,
67
+ statement_params: Optional[dict[str, Any]],
68
+ ) -> tuple[str, int]:
69
+ service_logs = service_client.get_service_logs(
70
+ database_name=self.database_name,
71
+ schema_name=self.schema_name,
72
+ service_name=self.service_name,
73
+ container_name=self.deployment_step.container_name,
74
+ statement_params=statement_params,
75
+ )
76
+
77
+ # return only new logs starting after the offset
78
+ new_logs = service_logs[offset:]
79
+ new_offset = max(offset, len(service_logs))
80
+
81
+ return new_logs, new_offset
82
+
39
83
 
40
84
  @dataclasses.dataclass
41
85
  class ServiceLogMetadata:
@@ -43,8 +87,47 @@ class ServiceLogMetadata:
43
87
  service: ServiceLogInfo
44
88
  service_status: Optional[service_sql.ServiceStatus]
45
89
  is_model_build_service_done: bool
90
+ is_model_logger_service_done: bool
46
91
  log_offset: int
47
92
 
93
+ def transition_service_log_metadata(
94
+ self,
95
+ to_service: ServiceLogInfo,
96
+ msg: str,
97
+ is_model_build_service_done: bool,
98
+ is_model_logger_service_done: bool,
99
+ propagate: bool = False,
100
+ ) -> None:
101
+ to_service_logger = service_logger.get_logger(
102
+ f"{to_service.display_service_name}-{to_service.instance_id}",
103
+ to_service.log_color,
104
+ )
105
+ to_service_logger.propagate = propagate
106
+ self.service_logger = to_service_logger
107
+ self.service = to_service
108
+ self.service_status = None
109
+ self.is_model_build_service_done = is_model_build_service_done
110
+ self.is_model_logger_service_done = is_model_logger_service_done
111
+ self.log_offset = 0
112
+ block_size = 180
113
+ module_logger.info(msg)
114
+ module_logger.info("-" * block_size)
115
+
116
+
117
+ @dataclasses.dataclass
118
+ class HFModelArgs:
119
+ hf_model_name: str
120
+ hf_task: Optional[str] = None
121
+ hf_tokenizer: Optional[str] = None
122
+ hf_revision: Optional[str] = None
123
+ hf_token: Optional[str] = None
124
+ hf_trust_remote_code: bool = False
125
+ hf_model_kwargs: Optional[dict[str, Any]] = None
126
+ pip_requirements: Optional[list[str]] = None
127
+ conda_dependencies: Optional[list[str]] = None
128
+ comment: Optional[str] = None
129
+ warehouse: Optional[str] = None
130
+
48
131
 
49
132
  class ServiceOperator:
50
133
  """Service operator for container services logic."""
@@ -109,6 +192,8 @@ class ServiceOperator:
109
192
  build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
110
193
  block: bool,
111
194
  statement_params: Optional[dict[str, Any]] = None,
195
+ # hf model
196
+ hf_model_args: Optional[HFModelArgs] = None,
112
197
  ) -> Union[str, async_job.AsyncJob]:
113
198
 
114
199
  # Fall back to the registry's database and schema if not provided
@@ -153,6 +238,21 @@ class ServiceOperator:
153
238
  num_workers=num_workers,
154
239
  max_batch_rows=max_batch_rows,
155
240
  )
241
+ if hf_model_args:
242
+ # hf model
243
+ self._model_deployment_spec.add_hf_logger_spec(
244
+ hf_model_name=hf_model_args.hf_model_name,
245
+ hf_task=hf_model_args.hf_task,
246
+ hf_token=hf_model_args.hf_token,
247
+ hf_tokenizer=hf_model_args.hf_tokenizer,
248
+ hf_revision=hf_model_args.hf_revision,
249
+ hf_trust_remote_code=hf_model_args.hf_trust_remote_code,
250
+ pip_requirements=hf_model_args.pip_requirements,
251
+ conda_dependencies=hf_model_args.conda_dependencies,
252
+ comment=hf_model_args.comment,
253
+ warehouse=hf_model_args.warehouse,
254
+ **(hf_model_args.hf_model_kwargs if hf_model_args.hf_model_kwargs else {}),
255
+ )
156
256
  spec_yaml_str_or_path = self._model_deployment_spec.save()
157
257
  if self._workspace:
158
258
  assert stage_path is not None
@@ -187,22 +287,47 @@ class ServiceOperator:
187
287
  )
188
288
 
189
289
  # stream service logs in a thread
190
- model_build_service_name = sql_identifier.SqlIdentifier(self._get_model_build_service_name(query_id))
290
+ model_build_service_name = sql_identifier.SqlIdentifier(
291
+ self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_BUILD)
292
+ )
191
293
  model_build_service = ServiceLogInfo(
192
294
  database_name=service_database_name,
193
295
  schema_name=service_schema_name,
194
296
  service_name=model_build_service_name,
195
- container_name="model-build",
297
+ deployment_step=DeploymentStep.MODEL_BUILD,
298
+ log_color=service_logger.LogColor.GREEN,
196
299
  )
197
300
  model_inference_service = ServiceLogInfo(
198
301
  database_name=service_database_name,
199
302
  schema_name=service_schema_name,
200
303
  service_name=service_name,
201
- container_name="model-inference",
304
+ deployment_step=DeploymentStep.MODEL_INFERENCE,
305
+ log_color=service_logger.LogColor.BLUE,
202
306
  )
203
- services = [model_build_service, model_inference_service]
307
+
308
+ model_logger_service: Optional[ServiceLogInfo] = None
309
+ if hf_model_args:
310
+ model_logger_service_name = sql_identifier.SqlIdentifier(
311
+ self._get_service_id_from_deployment_step(query_id, DeploymentStep.MODEL_LOGGING)
312
+ )
313
+
314
+ model_logger_service = ServiceLogInfo(
315
+ database_name=service_database_name,
316
+ schema_name=service_schema_name,
317
+ service_name=model_logger_service_name,
318
+ deployment_step=DeploymentStep.MODEL_LOGGING,
319
+ log_color=service_logger.LogColor.ORANGE,
320
+ )
321
+
322
+ # start service log streaming
204
323
  log_thread = self._start_service_log_streaming(
205
- async_job, services, model_inference_service_exists, force_rebuild, statement_params
324
+ async_job=async_job,
325
+ model_logger_service=model_logger_service,
326
+ model_build_service=model_build_service,
327
+ model_inference_service=model_inference_service,
328
+ model_inference_service_exists=model_inference_service_exists,
329
+ force_rebuild=force_rebuild,
330
+ statement_params=statement_params,
206
331
  )
207
332
 
208
333
  if block:
@@ -217,17 +342,22 @@ class ServiceOperator:
217
342
  def _start_service_log_streaming(
218
343
  self,
219
344
  async_job: snowpark.AsyncJob,
220
- services: list[ServiceLogInfo],
345
+ model_logger_service: Optional[ServiceLogInfo],
346
+ model_build_service: ServiceLogInfo,
347
+ model_inference_service: ServiceLogInfo,
221
348
  model_inference_service_exists: bool,
222
349
  force_rebuild: bool,
223
350
  statement_params: Optional[dict[str, Any]] = None,
224
351
  ) -> threading.Thread:
225
352
  """Start the service log streaming in a separate thread."""
353
+ # TODO: create a DAG of services and stream logs in the order of the DAG
226
354
  log_thread = threading.Thread(
227
355
  target=self._stream_service_logs,
228
356
  args=(
229
357
  async_job,
230
- services,
358
+ model_logger_service,
359
+ model_build_service,
360
+ model_inference_service,
231
361
  model_inference_service_exists,
232
362
  force_rebuild,
233
363
  statement_params,
@@ -236,151 +366,199 @@ class ServiceOperator:
236
366
  log_thread.start()
237
367
  return log_thread
238
368
 
239
- def _stream_service_logs(
369
+ def _fetch_log_and_update_meta(
240
370
  self,
241
- async_job: snowpark.AsyncJob,
242
- services: list[ServiceLogInfo],
243
- model_inference_service_exists: bool,
244
371
  force_rebuild: bool,
372
+ service_log_meta: ServiceLogMetadata,
373
+ model_build_service: ServiceLogInfo,
374
+ model_inference_service: ServiceLogInfo,
245
375
  statement_params: Optional[dict[str, Any]] = None,
246
376
  ) -> None:
247
- """Stream service logs while the async job is running."""
377
+ """Helper function to fetch logs and update the service log metadata if needed.
378
+
379
+ This function checks the service status and fetches logs if the service exists.
380
+ It also updates the service log metadata with the
381
+ new service status and logs.
382
+ If the service is done, it transitions the service log metadata.
383
+
384
+ Args:
385
+ force_rebuild: Whether to force rebuild the model build image.
386
+ service_log_meta: The ServiceLogMetadata holds the state of the service log metadata.
387
+ model_build_service: The ServiceLogInfo for the model build service.
388
+ model_inference_service: The ServiceLogInfo for the model inference service.
389
+ statement_params: The statement parameters to use for the service client.
390
+ """
391
+
392
+ service = service_log_meta.service
393
+ # check if using an existing model build image
394
+ if (
395
+ service.deployment_step == DeploymentStep.MODEL_BUILD
396
+ and not force_rebuild
397
+ and service_log_meta.is_model_logger_service_done
398
+ and not service_log_meta.is_model_build_service_done
399
+ ):
400
+ model_build_service_exists = self._check_if_service_exists(
401
+ database_name=service.database_name,
402
+ schema_name=service.schema_name,
403
+ service_name=service.service_name,
404
+ statement_params=statement_params,
405
+ )
406
+ new_model_inference_service_exists = self._check_if_service_exists(
407
+ database_name=model_inference_service.database_name,
408
+ schema_name=model_inference_service.schema_name,
409
+ service_name=model_inference_service.service_name,
410
+ statement_params=statement_params,
411
+ )
412
+ if not model_build_service_exists and new_model_inference_service_exists:
413
+ service_log_meta.transition_service_log_metadata(
414
+ model_inference_service,
415
+ "Model build is not rebuilding the inference image, but using a previously built image.",
416
+ is_model_build_service_done=True,
417
+ is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
418
+ )
248
419
 
249
- def fetch_logs(service: ServiceLogInfo, offset: int) -> tuple[str, int]:
250
- service_logs = self._service_client.get_service_logs(
420
+ try:
421
+ statuses = self._service_client.get_service_container_statuses(
251
422
  database_name=service.database_name,
252
423
  schema_name=service.schema_name,
253
424
  service_name=service.service_name,
254
- container_name=service.container_name,
425
+ include_message=True,
255
426
  statement_params=statement_params,
256
427
  )
428
+ service_status = statuses[0].service_status
429
+ except exceptions.SnowparkSQLException:
430
+ # If the service is not found, log that the service is not found
431
+ # and wait for a few seconds before returning
432
+ module_logger.info(f"Service status for service {service.display_service_name} not found.")
433
+ time.sleep(5)
434
+ return
257
435
 
258
- # return only new logs starting after the offset
259
- if len(service_logs) > offset:
260
- new_logs = service_logs[offset:]
261
- new_offset = len(service_logs)
262
- else:
263
- new_logs = ""
264
- new_offset = offset
436
+ # Case 1: service_status is PENDING or the service_status changed
437
+ if (service_status != service_sql.ServiceStatus.RUNNING) or (service_status != service_log_meta.service_status):
438
+ service_log_meta.service_status = service_status
265
439
 
266
- return new_logs, new_offset
440
+ if service.deployment_step == DeploymentStep.MODEL_BUILD:
441
+ module_logger.info(
442
+ f"Image build service {service.display_service_name} is "
443
+ f"{service_log_meta.service_status.value}."
444
+ )
445
+ elif service.deployment_step == DeploymentStep.MODEL_INFERENCE:
446
+ module_logger.info(
447
+ f"Inference service {service.display_service_name} is {service_log_meta.service_status.value}."
448
+ )
449
+ elif service.deployment_step == DeploymentStep.MODEL_LOGGING:
450
+ module_logger.info(
451
+ f"Model logger service {service.display_service_name} is "
452
+ f"{service_log_meta.service_status.value}."
453
+ )
454
+ for status in statuses:
455
+ if status.instance_id is not None:
456
+ instance_status, container_status = None, None
457
+ if status.instance_status is not None:
458
+ instance_status = status.instance_status.value
459
+ if status.container_status is not None:
460
+ container_status = status.container_status.value
461
+ module_logger.info(
462
+ f"Instance[{status.instance_id}]: "
463
+ f"instance status: {instance_status}, "
464
+ f"container status: {container_status}, "
465
+ f"message: {status.message}"
466
+ )
467
+ time.sleep(5)
267
468
 
268
- def set_service_log_metadata_to_model_inference(
269
- meta: ServiceLogMetadata, inference_service: ServiceLogInfo, msg: str
270
- ) -> None:
271
- model_inference_service_logger = service_logger.get_logger( # InferenceServiceName-InstanceId
272
- f"{inference_service.display_service_name}-{inference_service.instance_id}",
273
- service_logger.LogColor.BLUE,
469
+ # Case 2: service_status is RUNNING
470
+ # stream logs and update the log offset
471
+ if service_status == service_sql.ServiceStatus.RUNNING:
472
+ new_logs, new_offset = service.fetch_logs(
473
+ self._service_client,
474
+ service_log_meta.log_offset,
475
+ statement_params=statement_params,
274
476
  )
275
- model_inference_service_logger.propagate = False
276
- meta.service_logger = model_inference_service_logger
277
- meta.service = inference_service
278
- meta.service_status = None
279
- meta.is_model_build_service_done = True
280
- meta.log_offset = 0
281
- block_size = 180
282
- module_logger.info(msg)
283
- module_logger.info("-" * block_size)
284
-
285
- model_build_service, model_inference_service = services[0], services[1]
477
+ if new_logs:
478
+ service_log_meta.service_logger.info(new_logs)
479
+ service_log_meta.log_offset = new_offset
480
+
481
+ # Case 3: service_status is DONE
482
+ if service_status == service_sql.ServiceStatus.DONE:
483
+ # check if model logger service is done
484
+ # and transition the service log metadata to the model image build service
485
+ if service.deployment_step == DeploymentStep.MODEL_LOGGING:
486
+ service_log_meta.transition_service_log_metadata(
487
+ model_build_service,
488
+ f"Model Logger service {service.display_service_name} complete.",
489
+ is_model_build_service_done=False,
490
+ is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
491
+ )
492
+ # check if model build service is done
493
+ # and transition the service log metadata to the model inference service
494
+ elif service.deployment_step == DeploymentStep.MODEL_BUILD:
495
+ service_log_meta.transition_service_log_metadata(
496
+ model_inference_service,
497
+ f"Image build service {service.display_service_name} complete.",
498
+ is_model_build_service_done=True,
499
+ is_model_logger_service_done=service_log_meta.is_model_logger_service_done,
500
+ )
501
+ else:
502
+ module_logger.warning(f"Service {service.display_service_name} is done, but not transitioning.")
503
+
504
+ def _stream_service_logs(
505
+ self,
506
+ async_job: snowpark.AsyncJob,
507
+ model_logger_service: Optional[ServiceLogInfo],
508
+ model_build_service: ServiceLogInfo,
509
+ model_inference_service: ServiceLogInfo,
510
+ model_inference_service_exists: bool,
511
+ force_rebuild: bool,
512
+ statement_params: Optional[dict[str, Any]] = None,
513
+ ) -> None:
514
+ """Stream service logs while the async job is running."""
515
+
286
516
  model_build_service_logger = service_logger.get_logger( # BuildJobName
287
- model_build_service.display_service_name, service_logger.LogColor.GREEN
517
+ model_build_service.display_service_name,
518
+ model_build_service.log_color,
288
519
  )
289
520
  model_build_service_logger.propagate = False
290
- service_log_meta = ServiceLogMetadata(
291
- service_logger=model_build_service_logger,
292
- service=model_build_service,
293
- service_status=None,
294
- is_model_build_service_done=False,
295
- log_offset=0,
296
- )
521
+ if model_logger_service:
522
+ model_logger_service_logger = service_logger.get_logger( # ModelLoggerName
523
+ model_logger_service.display_service_name,
524
+ model_logger_service.log_color,
525
+ )
526
+ model_logger_service_logger.propagate = False
527
+
528
+ service_log_meta = ServiceLogMetadata(
529
+ service_logger=model_logger_service_logger,
530
+ service=model_logger_service,
531
+ service_status=None,
532
+ is_model_build_service_done=False,
533
+ is_model_logger_service_done=False,
534
+ log_offset=0,
535
+ )
536
+ else:
537
+ service_log_meta = ServiceLogMetadata(
538
+ service_logger=model_build_service_logger,
539
+ service=model_build_service,
540
+ service_status=None,
541
+ is_model_build_service_done=False,
542
+ is_model_logger_service_done=True,
543
+ log_offset=0,
544
+ )
545
+
297
546
  while not async_job.is_done():
298
547
  if model_inference_service_exists:
299
548
  time.sleep(5)
300
549
  continue
301
550
 
302
551
  try:
303
- # check if using an existing model build image
304
- if not force_rebuild and not service_log_meta.is_model_build_service_done:
305
- model_build_service_exists = self._check_if_service_exists(
306
- database_name=model_build_service.database_name,
307
- schema_name=model_build_service.schema_name,
308
- service_name=model_build_service.service_name,
309
- statement_params=statement_params,
310
- )
311
- new_model_inference_service_exists = self._check_if_service_exists(
312
- database_name=model_inference_service.database_name,
313
- schema_name=model_inference_service.schema_name,
314
- service_name=model_inference_service.service_name,
315
- statement_params=statement_params,
316
- )
317
- if not model_build_service_exists and new_model_inference_service_exists:
318
- set_service_log_metadata_to_model_inference(
319
- service_log_meta,
320
- model_inference_service,
321
- (
322
- "Model Inference image build is not rebuilding the image, but using a previously built "
323
- "image."
324
- ),
325
- )
326
- continue
327
-
328
- statuses = self._service_client.get_service_container_statuses(
329
- database_name=service_log_meta.service.database_name,
330
- schema_name=service_log_meta.service.schema_name,
331
- service_name=service_log_meta.service.service_name,
332
- include_message=True,
552
+ # fetch logs for the service
553
+ # (model logging, model build, or model inference)
554
+ # upon completion, transition to the next service if any
555
+ self._fetch_log_and_update_meta(
556
+ service_log_meta=service_log_meta,
557
+ force_rebuild=force_rebuild,
558
+ model_build_service=model_build_service,
559
+ model_inference_service=model_inference_service,
333
560
  statement_params=statement_params,
334
561
  )
335
- service_status = statuses[0].service_status
336
- if (service_status != service_sql.ServiceStatus.RUNNING) or (
337
- service_status != service_log_meta.service_status
338
- ):
339
- service_log_meta.service_status = service_status
340
- module_logger.info(
341
- f"{'Inference' if service_log_meta.is_model_build_service_done else 'Image build'} service "
342
- f"{service_log_meta.service.display_service_name} is "
343
- f"{service_log_meta.service_status.value}."
344
- )
345
- for status in statuses:
346
- if status.instance_id is not None:
347
- instance_status, container_status = None, None
348
- if status.instance_status is not None:
349
- instance_status = status.instance_status.value
350
- if status.container_status is not None:
351
- container_status = status.container_status.value
352
- module_logger.info(
353
- f"Instance[{status.instance_id}]: "
354
- f"instance status: {instance_status}, "
355
- f"container status: {container_status}, "
356
- f"message: {status.message}"
357
- )
358
-
359
- new_logs, new_offset = fetch_logs(
360
- service_log_meta.service,
361
- service_log_meta.log_offset,
362
- )
363
- if new_logs:
364
- service_log_meta.service_logger.info(new_logs)
365
- service_log_meta.log_offset = new_offset
366
-
367
- # check if model build service is done
368
- if not service_log_meta.is_model_build_service_done:
369
- statuses = self._service_client.get_service_container_statuses(
370
- database_name=model_build_service.database_name,
371
- schema_name=model_build_service.schema_name,
372
- service_name=model_build_service.service_name,
373
- include_message=False,
374
- statement_params=statement_params,
375
- )
376
- service_status = statuses[0].service_status
377
-
378
- if service_status == service_sql.ServiceStatus.DONE:
379
- set_service_log_metadata_to_model_inference(
380
- service_log_meta,
381
- model_inference_service,
382
- f"Image build service {model_build_service.display_service_name} complete.",
383
- )
384
562
  except Exception as ex:
385
563
  pattern = r"002003 \(02000\)" # error code: service does not exist
386
564
  is_snowpark_sql_exception = isinstance(ex, exceptions.SnowparkSQLException)
@@ -388,8 +566,7 @@ class ServiceOperator:
388
566
  matches_pattern = service_log_meta.service_status is None and re.search(pattern, str(ex)) is not None
389
567
  if not (is_snowpark_sql_exception and (contains_msg or matches_pattern)):
390
568
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
391
-
392
- time.sleep(5)
569
+ time.sleep(5)
393
570
 
394
571
  if model_inference_service_exists:
395
572
  module_logger.info(
@@ -397,7 +574,10 @@ class ServiceOperator:
397
574
  )
398
575
  else:
399
576
  self._finalize_logs(
400
- service_log_meta.service_logger, service_log_meta.service, service_log_meta.log_offset, statement_params
577
+ service_log_meta.service_logger,
578
+ service_log_meta.service,
579
+ service_log_meta.log_offset,
580
+ statement_params,
401
581
  )
402
582
 
403
583
  def _finalize_logs(
@@ -414,7 +594,7 @@ class ServiceOperator:
414
594
  database_name=service.database_name,
415
595
  schema_name=service.schema_name,
416
596
  service_name=service.service_name,
417
- container_name=service.container_name,
597
+ container_name=service.deployment_step.container_name,
418
598
  statement_params=statement_params,
419
599
  )
420
600
 
@@ -424,13 +604,17 @@ class ServiceOperator:
424
604
  module_logger.warning(f"Caught an exception when logging: {repr(ex)}")
425
605
 
426
606
  @staticmethod
427
- def _get_model_build_service_name(query_id: str) -> str:
428
- """Get the model build service name through the server-side logic."""
607
+ def _get_service_id_from_deployment_step(query_id: str, deployment_step: DeploymentStep) -> str:
608
+ """Get the service ID through the server-side logic."""
429
609
  uuid = query_id.replace("-", "")
430
610
  big_int = int(uuid, 16)
431
611
  md5_hash = hashlib.md5(str(big_int).encode()).hexdigest()
432
612
  identifier = md5_hash[:8]
433
- return ("model_build_" + identifier).upper()
613
+ service_name_prefix = deployment_step.service_name_prefix
614
+ if service_name_prefix is None:
615
+ # raise an exception if the service name prefix is None
616
+ raise ValueError(f"Service name prefix is {service_name_prefix} for deployment step {deployment_step}.")
617
+ return (service_name_prefix + identifier).upper()
434
618
 
435
619
  def _check_if_service_exists(
436
620
  self,
@@ -518,7 +702,7 @@ class ServiceOperator:
518
702
  output_with_input_features = False
519
703
  df = model_signature._convert_and_validate_local_data(X, signature.inputs)
520
704
  s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
521
- self._session, df, keep_order=keep_order, features=signature.inputs
705
+ self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
522
706
  )
523
707
  else:
524
708
  keep_order = False
@@ -630,7 +814,9 @@ class ServiceOperator:
630
814
 
631
815
  # get final result
632
816
  if not isinstance(X, dataframe.DataFrame):
633
- return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
817
+ return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
818
+ df_res, features=signature.outputs, statement_params=statement_params
819
+ )
634
820
  else:
635
821
  return df_res
636
822
 
@@ -347,7 +347,7 @@ class ModelDeploymentSpec:
347
347
  hf_model = model_deployment_spec_schema.HuggingFaceModel(
348
348
  hf_model_name=hf_model_name,
349
349
  task=hf_task,
350
- hf_token=hf_token,
350
+ token=hf_token,
351
351
  tokenizer=hf_tokenizer,
352
352
  trust_remote_code=hf_trust_remote_code,
353
353
  revision=hf_revision,