snowflake-ml-python 1.9.2__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/service_logger.py +31 -17
- snowflake/ml/experiment/callback/lightgbm.py +55 -0
- snowflake/ml/experiment/callback/xgboost.py +63 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/payload_utils.py +13 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +56 -48
- snowflake/ml/model/_client/ops/service_ops.py +177 -12
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/models/huggingface_pipeline.py +71 -49
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +30 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +505 -492
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +20 -17
- snowflake/ml/experiment/callback.py +0 -121
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
|
@@ -10,7 +10,7 @@ snowflake/cortex/_sse_client.py,sha256=sLYgqAfTOPADCnaWH2RWAJi8KbU_7gSRsTUDcDD5T
|
|
|
10
10
|
snowflake/cortex/_summarize.py,sha256=7GH8zqfIdOiHA5w4b6EvJEKEWhaTrL4YA6iDGbn7BNM,1307
|
|
11
11
|
snowflake/cortex/_translate.py,sha256=9ZGjvAnJFisbzJ_bXnt4pyug5UzhHJRXW8AhGQEersM,1652
|
|
12
12
|
snowflake/cortex/_util.py,sha256=krNTpbkFLXwdFqy1bd0xi7ZmOzOHRnIfHdQCPiLZJxk,3288
|
|
13
|
-
snowflake/ml/version.py,sha256=
|
|
13
|
+
snowflake/ml/version.py,sha256=B5r4kxP_Y86tCkMrzCJRl5C4J8HJqV-KhqQGS6r0Klo,99
|
|
14
14
|
snowflake/ml/_internal/env.py,sha256=EY_2KVe8oR3LgKWdaeRb5rRU-NDNXJppPDsFJmMZUUY,265
|
|
15
15
|
snowflake/ml/_internal/env_utils.py,sha256=x6ID94g6FYoMX3afp0zoUHzBvuvPyiE2F6RDpxx5Cq0,30967
|
|
16
16
|
snowflake/ml/_internal/file_utils.py,sha256=7sA6loOeSfmGP4yx16P4usT9ZtRqG3ycnXu7_Tk7dOs,14206
|
|
@@ -45,7 +45,7 @@ snowflake/ml/_internal/utils/parallelize.py,sha256=l8Zjo-hp8zqoLgHxBlpz9Zmn2Z-MR
|
|
|
45
45
|
snowflake/ml/_internal/utils/pkg_version_utils.py,sha256=EaY_3IsVOZ9BCH28F5VLjp-0AiEqDlL7L715vkPsgrY,5149
|
|
46
46
|
snowflake/ml/_internal/utils/query_result_checker.py,sha256=1PR41Xn9BUIXvp-UmJ9FgEbj8WfgU7RUhz3PqvvVQ5E,10656
|
|
47
47
|
snowflake/ml/_internal/utils/result.py,sha256=59Sz6MvhjakUNiONwg9oi2544AmORCJR3XyWTxY2vP0,2405
|
|
48
|
-
snowflake/ml/_internal/utils/service_logger.py,sha256=
|
|
48
|
+
snowflake/ml/_internal/utils/service_logger.py,sha256=LmADyxsSE3-TYBX1gCYtxvaEDdH_Lf6d5gRt44uue0I,6267
|
|
49
49
|
snowflake/ml/_internal/utils/snowflake_env.py,sha256=k4ddzs8iJpRpVvgbbOtU8j4fUvqa77Awk65EJ5j2uxk,4253
|
|
50
50
|
snowflake/ml/_internal/utils/snowpark_dataframe_utils.py,sha256=tm2leAu_oDTNUQZJ98UpKtS79k-A-c72pKxd-8AE-tg,6353
|
|
51
51
|
snowflake/ml/_internal/utils/sql_identifier.py,sha256=YHIwXpb8E1U6LVUVpT8q7s9ZygONAXKPVMD4IucwXx8,4669
|
|
@@ -65,13 +65,15 @@ snowflake/ml/dataset/dataset_metadata.py,sha256=lcNvugBkP8YEkGMQqaV8SlHs5mwUKsUS
|
|
|
65
65
|
snowflake/ml/dataset/dataset_reader.py,sha256=mZsG9HyWUGgfotrGkLrunyEsOm_659mH-Sn2OyG6A-Q,5036
|
|
66
66
|
snowflake/ml/experiment/__init__.py,sha256=r7qdyPd3jwxzqvksim2ju5j_LrnYQrta0ZI6XpWUqmc,109
|
|
67
67
|
snowflake/ml/experiment/_experiment_info.py,sha256=iaJ65x6nzBYJ5djleSOzBtMpZUJCUDlRpaDw0pu-dcU,2533
|
|
68
|
-
snowflake/ml/experiment/callback.py,sha256=I1U-1kXLFDTdnVy-xr3H7-xuoPrPoZmAv2VSYvvGRXU,5047
|
|
69
68
|
snowflake/ml/experiment/experiment_tracking.py,sha256=ljY2HA1E724qnBmvuEzQA8o3ZT0XMJkeIYJ97vXPx5A,11316
|
|
69
|
+
snowflake/ml/experiment/utils.py,sha256=3bpbkilc5vvFjnti-kcyhhjAd9Ga3LqiKqJDwORiATY,628
|
|
70
70
|
snowflake/ml/experiment/_client/experiment_tracking_sql_client.py,sha256=rdCBHRqTYW6I2ztCpO-Zyb9nd_0HV26QdpGMDwxZ144,4446
|
|
71
71
|
snowflake/ml/experiment/_entities/__init__.py,sha256=ThrslBFuDxOUvdS8j_bVmEaEAms8nR1aY0ocYFnVPFg,155
|
|
72
72
|
snowflake/ml/experiment/_entities/experiment.py,sha256=lKmQj59K8fGDWVwRqeIesxorrChb-m78vX_WUmI7PV0,225
|
|
73
73
|
snowflake/ml/experiment/_entities/run.py,sha256=_bWt1YpP8iulg5jeBXMXw8zGZHr9zSE9IVIBHcCdfto,2293
|
|
74
74
|
snowflake/ml/experiment/_entities/run_metadata.py,sha256=j8V2N6QBAx4TP4h7MLIPXqquYI8KyNZnkW6wzm-peuY,1589
|
|
75
|
+
snowflake/ml/experiment/callback/lightgbm.py,sha256=JypczGEpvAtYmXT4785Obny7B2-zNkpBurnAWFVIM-Y,2368
|
|
76
|
+
snowflake/ml/experiment/callback/xgboost.py,sha256=RiXL6ft4GOwKjE_POJcNgon44pq3BIOAy2SYmAwJMuc,2384
|
|
75
77
|
snowflake/ml/feature_store/__init__.py,sha256=MJr2Gp_EimDgDxD6DtenOEdLTzg6NYPfdNiPM-5rEtw,406
|
|
76
78
|
snowflake/ml/feature_store/access_manager.py,sha256=Q5ImMXRY8WA5X5dpBMzHnIJmeyKVShjNAlbn3cQb4N8,10654
|
|
77
79
|
snowflake/ml/feature_store/entity.py,sha256=ViOSlqCV17ouiO4iH-_KvkvJZqSzpf-nfsjijG6G1Uk,4047
|
|
@@ -111,30 +113,30 @@ snowflake/ml/jobs/manager.py,sha256=M_qhnAdMDYPWL2hQscDQqzeavzEricQ5WjztcGn5XGo,
|
|
|
111
113
|
snowflake/ml/jobs/_utils/constants.py,sha256=sdidOyW2X81u0E30DU4K5aPjBTzNQmYtVyZ9D8mQaL4,4066
|
|
112
114
|
snowflake/ml/jobs/_utils/function_payload_utils.py,sha256=4LBaStMdhRxcqwRkwFje-WwiEKRWnBfkaOYouF3N3Kg,1308
|
|
113
115
|
snowflake/ml/jobs/_utils/interop_utils.py,sha256=7mODMTjKCLXkJloACG6_9b2wvmRgjXF0Jx3wpWYyJeA,21413
|
|
114
|
-
snowflake/ml/jobs/_utils/payload_utils.py,sha256=
|
|
116
|
+
snowflake/ml/jobs/_utils/payload_utils.py,sha256=Qhj4NRrZG1Wx3GijyiFTsJyIv6fdgPfREmvQ5uDybvw,28870
|
|
115
117
|
snowflake/ml/jobs/_utils/query_helper.py,sha256=h5s-_MgHc_f9AmXD5C06frHdP84n9Rmevb1Yu6R1w7s,910
|
|
116
118
|
snowflake/ml/jobs/_utils/spec_utils.py,sha256=VhdLXtDJXi5RJpC4exUUmx2gIb39qu-SQ5VraQi4KLc,13429
|
|
117
119
|
snowflake/ml/jobs/_utils/stage_utils.py,sha256=frjXVvnzFIJCoCWeLF_5x6LsKMq20vp4q1fZvwbXONc,4734
|
|
118
120
|
snowflake/ml/jobs/_utils/types.py,sha256=jXePdeg_KWVSDzs-afRTNx0m4U4MRdRF0rZxDobuNq8,2346
|
|
119
121
|
snowflake/ml/jobs/_utils/scripts/constants.py,sha256=YyIWZqQPYOTtgCY6SfyJjk2A98I5RQVmrOuLtET5Pqg,173
|
|
120
122
|
snowflake/ml/jobs/_utils/scripts/get_instance_ip.py,sha256=DmWs5cVpNmUcrqnwhrUvxE5PycDWFN88Pdut8vFDHPg,5293
|
|
121
|
-
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py,sha256=
|
|
123
|
+
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py,sha256=TgadoKeIG0he3kxeUij_-K9SYIrS6iKGzO0JT2d-O3k,14954
|
|
122
124
|
snowflake/ml/jobs/_utils/scripts/signal_workers.py,sha256=AR1Pylkm4-FGh10WXfrCtcxaV0rI7IQ2ZiO0Li7zZ3U,7433
|
|
123
125
|
snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py,sha256=SeJ8v5XDriwHAjIGpcQkwVP-f-lO9QIdVjVD7Fkgafs,7893
|
|
124
126
|
snowflake/ml/lineage/__init__.py,sha256=8p1YGynC-qOxAZ8jZX2z84Reg5bv1NoJMoJmNJCrzI4,65
|
|
125
127
|
snowflake/ml/lineage/lineage_node.py,sha256=vmikk4qaZuVFhQqW-VM6DuW4tDvmQlNbACvIVZEamcU,5830
|
|
126
128
|
snowflake/ml/model/__init__.py,sha256=EvPtblqPN6_T6dyVfaYUxCfo_M7D2CQ1OR5giIH4TsQ,314
|
|
127
129
|
snowflake/ml/model/custom_model.py,sha256=fDhMObqlyzD_qQG1Bq6HHkBN1w3Qzg9e81JWPiqRfc4,12249
|
|
128
|
-
snowflake/ml/model/event_handler.py,sha256=
|
|
130
|
+
snowflake/ml/model/event_handler.py,sha256=pojleQVM9TPNeDvliTvon2Sfxqbf2WWxrOebo1SaEHo,7211
|
|
129
131
|
snowflake/ml/model/model_signature.py,sha256=RH62vv4YmrQugTXLsh6kyuzfTs9_yz8a0TMkBR67XKY,32324
|
|
130
132
|
snowflake/ml/model/target_platform.py,sha256=H5d-wtuKQyVlq9x33vPtYZAlR5ka0ytcKRYgwlKl0bQ,390
|
|
131
133
|
snowflake/ml/model/task.py,sha256=Zp5JaLB-YfX5p_HSaw81P3J7UnycQq5EMa87A35VOaQ,286
|
|
132
|
-
snowflake/ml/model/type_hints.py,sha256=
|
|
134
|
+
snowflake/ml/model/type_hints.py,sha256=G0kp85-ksnYoAUHRdXxLFQBLq3XURuqYOpu_YeKEaNA,9847
|
|
133
135
|
snowflake/ml/model/_client/model/model_impl.py,sha256=Yabrbir5vPMOnsVmQJ23YN7vqhi756Jcm6pfO8Aq92o,17469
|
|
134
|
-
snowflake/ml/model/_client/model/model_version_impl.py,sha256=
|
|
136
|
+
snowflake/ml/model/_client/model/model_version_impl.py,sha256=z4jl0aVSXywi3DBtKU4zvvQZpkfsri9ZmULiyeLx3Tc,48132
|
|
135
137
|
snowflake/ml/model/_client/ops/metadata_ops.py,sha256=qpK6PL3OyfuhyOmpvLCpHLy6vCxbZbp1HlEvakFGwv4,4884
|
|
136
138
|
snowflake/ml/model/_client/ops/model_ops.py,sha256=z3T71w9ZNIU5eEA5G59Ous59WzEBs3YBcPO1_zeMI8M,48586
|
|
137
|
-
snowflake/ml/model/_client/ops/service_ops.py,sha256
|
|
139
|
+
snowflake/ml/model/_client/ops/service_ops.py,sha256=Io7Onza0fH3M6vV_HznurBk8HSrlpjJcnpjbvyOm95A,44843
|
|
138
140
|
snowflake/ml/model/_client/service/model_deployment_spec.py,sha256=07b2vdtUEq-a6BZOqX5sqGhK8SI-L1597IgBgoX0XW0,17505
|
|
139
141
|
snowflake/ml/model/_client/service/model_deployment_spec_schema.py,sha256=esuS0MsBEzEyh8ifJH4JOUkSsdX_KL1-KT8KswEHltg,1945
|
|
140
142
|
snowflake/ml/model/_client/sql/_base.py,sha256=Qrm8M92g3MHb-QnSLUlbd8iVKCRxLhG_zr5M2qmXwJ8,1473
|
|
@@ -195,7 +197,7 @@ snowflake/ml/model/_signatures/pytorch_handler.py,sha256=Xy-ITCCX_EgHcyIIqeYSDUI
|
|
|
195
197
|
snowflake/ml/model/_signatures/snowpark_handler.py,sha256=YOBC_Wx-H8bQ967A47nYgqcqLjEA15FbZK69TyAEgvU,7590
|
|
196
198
|
snowflake/ml/model/_signatures/tensorflow_handler.py,sha256=_yrvMg-w_jJoYuyrGXKPX4Dv7Vt8z1e6xIKiWGuZcc4,5660
|
|
197
199
|
snowflake/ml/model/_signatures/utils.py,sha256=vIs12OF_UKH7qrY0JATU-yZhLTgaKt1MJoEemRULA20,17275
|
|
198
|
-
snowflake/ml/model/models/huggingface_pipeline.py,sha256=
|
|
200
|
+
snowflake/ml/model/models/huggingface_pipeline.py,sha256=XNR3MhQpY8Mf7uLOao3r0_oDgPft_-0GStPyvSND-a0,19541
|
|
199
201
|
snowflake/ml/modeling/_internal/estimator_utils.py,sha256=oGi5qbZeV-1cM1Pl-rZLBvcr3YRoUzN_te_l-18apLI,11993
|
|
200
202
|
snowflake/ml/modeling/_internal/model_specifications.py,sha256=3wFMcKPCSoiEzU7Mx6RVem89BRlBBENpX__-Rd7GwdU,4851
|
|
201
203
|
snowflake/ml/modeling/_internal/model_trainer.py,sha256=5Ck1lbdyzcd-TpzAxEyovIN9fjaaVIqugyMHXt0wzH0,971
|
|
@@ -422,15 +424,16 @@ snowflake/ml/monitoring/_client/queries/rmse.ssql,sha256=OEJiSStRz9-qKoZaFvmubtY
|
|
|
422
424
|
snowflake/ml/monitoring/_manager/model_monitor_manager.py,sha256=0jpT1-aRU2tsxSM87I-C2kfJeLevCgM-a-OwU_-VUdI,10302
|
|
423
425
|
snowflake/ml/monitoring/entities/model_monitor_config.py,sha256=1W6TFTPicC6YAbjD7A0w8WMhWireyUxyuEy0RQXmqyY,1787
|
|
424
426
|
snowflake/ml/registry/__init__.py,sha256=XdPQK9ejYkSJVrSQ7HD3jKQO0hKq2mC4bPCB6qrtH3U,76
|
|
425
|
-
snowflake/ml/registry/registry.py,sha256=
|
|
426
|
-
snowflake/ml/registry/_manager/model_manager.py,sha256=
|
|
427
|
+
snowflake/ml/registry/registry.py,sha256=Ro7flVHv3FnEU9Ly3zWRnDAqWiwRSOA2uw_MSKmCBTI,32936
|
|
428
|
+
snowflake/ml/registry/_manager/model_manager.py,sha256=0yhu1QAuwK3hucNSwyJWIb7_TFL7w2Hj-fW4hpIueW4,19207
|
|
429
|
+
snowflake/ml/registry/_manager/model_parameter_reconciler.py,sha256=uWM5gD_NMHjVn4XUSdGsl0YpN62okZ2HK56xX8CQFuc,4516
|
|
427
430
|
snowflake/ml/utils/authentication.py,sha256=E1at4TIAQRDZDsMXSbrKvSJaT6_kSYJBkkr37vU9P2s,2606
|
|
428
431
|
snowflake/ml/utils/connection_params.py,sha256=JuadbzKlgDZLZ5vJ9cnyAiSitvZT9jGSfSSNjIY9P1Q,8282
|
|
429
432
|
snowflake/ml/utils/html_utils.py,sha256=L4pzpvFd20SIk4rie2kTAtcQjbxBHfjKmxonMAT2OoA,7665
|
|
430
433
|
snowflake/ml/utils/sparse.py,sha256=zLBNh-ynhGpKH5TFtopk0YLkHGvv0yq1q-sV59YQKgg,3819
|
|
431
434
|
snowflake/ml/utils/sql_client.py,sha256=pSe2od6Pkh-8NwG3D-xqN76_uNf-ohOtVbT55HeQg1Y,668
|
|
432
|
-
snowflake_ml_python-1.
|
|
433
|
-
snowflake_ml_python-1.
|
|
434
|
-
snowflake_ml_python-1.
|
|
435
|
-
snowflake_ml_python-1.
|
|
436
|
-
snowflake_ml_python-1.
|
|
435
|
+
snowflake_ml_python-1.10.0.dist-info/licenses/LICENSE.txt,sha256=PdEp56Av5m3_kl21iFkVTX_EbHJKFGEdmYeIO1pL_Yk,11365
|
|
436
|
+
snowflake_ml_python-1.10.0.dist-info/METADATA,sha256=G5N7mWVXwReXU_w3-74NYAUtABrIpQpjDODSx6fCCWM,90765
|
|
437
|
+
snowflake_ml_python-1.10.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
438
|
+
snowflake_ml_python-1.10.0.dist-info/top_level.txt,sha256=TY0gFSHKDdZy3THb0FGomyikWQasEGldIR1O0HGOHVw,10
|
|
439
|
+
snowflake_ml_python-1.10.0.dist-info/RECORD,,
|
|
@@ -1,121 +0,0 @@
|
|
|
1
|
-
import json
|
|
2
|
-
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
3
|
-
from warnings import warn
|
|
4
|
-
|
|
5
|
-
import lightgbm as lgb
|
|
6
|
-
import xgboost as xgb
|
|
7
|
-
|
|
8
|
-
from snowflake.ml.model.model_signature import ModelSignature
|
|
9
|
-
|
|
10
|
-
if TYPE_CHECKING:
|
|
11
|
-
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
    """XGBoost training callback that reports to Snowflake experiment tracking.

    Depending on the switches passed at construction time it:
      * logs the booster configuration (flattened to dotted keys) before training,
      * logs the latest value of every evaluation metric after each boosting round,
      * logs the trained booster as a model once training finishes.
    """

    def __init__(
        self,
        experiment_tracking: "ExperimentTracking",
        log_model: bool = True,
        log_metrics: bool = True,
        log_params: bool = True,
        model_name: Optional[str] = None,
        model_signature: Optional[ModelSignature] = None,
    ) -> None:
        """Store the tracking client and the logging switches.

        Args:
            experiment_tracking: Client that receives every log_params / log_metric / log_model call.
            log_model: When true, log the trained booster in ``after_training``.
            log_metrics: When true, log evaluation metrics in ``after_iteration``.
            log_params: When true, log the flattened booster config in ``before_training``.
            model_name: Name used for the logged model. When falsy, the name falls back to
                "<current experiment name>_model".
            model_signature: Signature registered for the "predict" method. When it is missing,
                the model is not logged and a warning is emitted instead.
        """
        self._experiment_tracking = experiment_tracking
        self.log_model = log_model
        self.log_metrics = log_metrics
        self.log_params = log_params
        self.model_name = model_name
        self.model_signature = model_signature

    @staticmethod
    def _flatten_config(node: Union[list[Any], dict[str, Any]], path: str = "") -> dict[str, Any]:
        """Flatten nested dicts/lists into a single dict keyed by dotted paths.

        Dict keys and list indices become path segments, e.g. {"a": {"b": [1]}} -> {"a.b.0": 1}.
        Empty nested containers contribute no entries.
        """
        flattened: dict[str, Any] = {}
        pairs = node.items() if isinstance(node, dict) else enumerate(node)
        for key, child in pairs:
            # A falsy (empty) path means we are at the root, so the key stands alone.
            child_path = f"{path}.{key}" if path else str(key)
            if isinstance(child, (dict, list)):
                flattened.update(SnowflakeXgboostCallback._flatten_config(child, child_path))
            else:
                flattened[child_path] = child
        return flattened

    def before_training(self, model: xgb.Booster) -> xgb.Booster:
        """Log the booster's JSON config as flat params when ``log_params`` is enabled.

        Returns:
            The unchanged booster, as the XGBoost callback protocol requires.
        """
        if not self.log_params:
            return model
        config = json.loads(model.save_config())
        self._experiment_tracking.log_params(self._flatten_config(config))
        return model

    def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
        """Log the most recent value of each (dataset, metric) pair for this round.

        Each metric is logged under the key "<dataset>:<metric>", with ``epoch`` as the step.

        Returns:
            Always False, so this callback never asks XGBoost to stop training.
        """
        if not self.log_metrics:
            return False
        for dataset_name, per_metric in evals_log.items():
            for metric_name, history in per_metric.items():
                self._experiment_tracking.log_metric(
                    key=":".join((dataset_name, metric_name)),
                    value=history[-1],
                    step=epoch,
                )
        return False

    def after_training(self, model: xgb.Booster) -> xgb.Booster:
        """Log the trained booster as a model when ``log_model`` is enabled.

        The model is skipped, with a warning, when no ``model_signature`` was supplied.

        Returns:
            The unchanged booster.
        """
        if not self.log_model:
            return model

        if not self.model_signature:
            warn(
                "Model will not be logged because model signature is missing. "
                "To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
            )
            return model

        resolved_name = self.model_name or (self._experiment_tracking._get_or_set_experiment().name + "_model")
        self._experiment_tracking.log_model(  # type: ignore[call-arg]
            model=model,
            model_name=resolved_name,
            signatures={"predict": self.model_signature},
        )
        return model
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
    """LightGBM callback that reports to Snowflake experiment tracking.

    It builds on LightGBM's evaluation-recording callback, so ``self.eval_result``
    accumulates metric histories. On each call it may:
      * log ``env.params`` on the first iteration,
      * log the latest value of every recorded metric,
      * log the trained model on iteration ``end_iteration - 1``.
    """

    def __init__(
        self,
        experiment_tracking: "ExperimentTracking",
        log_model: bool = True,
        log_metrics: bool = True,
        log_params: bool = True,
        model_name: Optional[str] = None,
        model_signature: Optional[ModelSignature] = None,
    ) -> None:
        """Store the tracking client and the logging switches, then init the recorder.

        Args:
            experiment_tracking: Client that receives every log_params / log_metric / log_model call.
            log_model: When true, log ``env.model`` on the last iteration.
            log_metrics: When true, record and log evaluation metrics every iteration.
            log_params: When true, log ``env.params`` on the first iteration.
            model_name: Name used for the logged model. When falsy, the name falls back to
                "<current experiment name>_model".
            model_signature: Signature registered for the "predict" method. When it is missing,
                the model is not logged and a warning is emitted instead.
        """
        self._experiment_tracking = experiment_tracking
        self.log_model = log_model
        self.log_metrics = log_metrics
        self.log_params = log_params
        self.model_name = model_name
        self.model_signature = model_signature

        # The parent recorder fills this fresh dict with per-dataset, per-metric histories.
        super().__init__(eval_result={})

    def __call__(self, env: lgb.callback.CallbackEnv) -> None:
        """Handle one LightGBM iteration: params, then metrics, then (maybe) the model."""
        on_first_iteration = env.iteration == env.begin_iteration
        # NOTE(review): the model is logged only when this exact iteration is reached.
        # If training ends earlier (e.g. via early stopping), the model would not be
        # logged. Confirm whether that is intended.
        on_last_iteration = env.iteration == env.end_iteration - 1

        if self.log_params and on_first_iteration:
            self._experiment_tracking.log_params(env.params)

        if self.log_metrics:
            # Let the parent append this iteration's results to self.eval_result first.
            super().__call__(env)
            for dataset_name, per_metric in self.eval_result.items():
                for metric_name, history in per_metric.items():
                    self._experiment_tracking.log_metric(
                        key=":".join((dataset_name, metric_name)),
                        value=history[-1],
                        step=env.iteration,
                    )

        if not (self.log_model and on_last_iteration):
            return

        if not self.model_signature:
            warn(
                "Model will not be logged because model signature is missing. To autolog the model, "
                "please specify `model_signature` when constructing SnowflakeLightgbmCallback."
            )
            return

        resolved_name = self.model_name or (self._experiment_tracking._get_or_set_experiment().name + "_model")
        self._experiment_tracking.log_model(  # type: ignore[call-arg]
            model=env.model,
            model_name=resolved_name,
            signatures={"predict": self.model_signature},
        )
|
|
File without changes
|
{snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt
RENAMED
|
File without changes
|
|
File without changes
|