snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +54 -42
- snowflake/ml/_internal/utils/service_logger.py +105 -3
- snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
- snowflake/ml/data/data_connector.py +13 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +2 -1
- snowflake/ml/dataset/dataset_reader.py +14 -4
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +156 -54
- snowflake/ml/jobs/_utils/query_helper.py +16 -5
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
- snowflake/ml/jobs/_utils/spec_utils.py +23 -8
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +70 -75
- snowflake/ml/jobs/manager.py +59 -31
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/service_ops.py +336 -137
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +11 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +32 -15
- snowflake/ml/registry/registry.py +48 -80
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/experiment/_client/experiment_tracking_sql_client.py (new file)
@@ -0,0 +1,98 @@
+from typing import Optional
+
+from snowflake.ml._internal.utils import query_result_checker, sql_identifier
+from snowflake.ml.model._client.sql import _base
+from snowflake.ml.utils import sql_client
+from snowflake.snowpark import row, session
+
+
+class ExperimentTrackingSQLClient(_base._BaseSQLClient):
+
+    RUN_NAME_COL_NAME = "name"
+    RUN_METADATA_COL_NAME = "metadata"
+
+    def __init__(
+        self,
+        session: session.Session,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        """Snowflake SQL Client to manage experiment tracking.
+
+        Args:
+            session: Active snowpark session.
+            database_name: Name of the Database where experiment tracking resources are provisioned.
+            schema_name: Name of the Schema where experiment tracking resources are provisioned.
+        """
+        super().__init__(session, database_name=database_name, schema_name=schema_name)
+
+    def create_experiment(
+        self,
+        experiment_name: sql_identifier.SqlIdentifier,
+        creation_mode: sql_client.CreationMode,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        if_not_exists_sql = "IF NOT EXISTS" if creation_mode.if_not_exists else ""
+        query_result_checker.SqlResultValidator(
+            self._session, f"CREATE EXPERIMENT {if_not_exists_sql} {experiment_fqn}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def drop_experiment(self, *, experiment_name: sql_identifier.SqlIdentifier) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(self._session, f"DROP EXPERIMENT {experiment_fqn}").has_dimensions(
+            expected_rows=1, expected_cols=1
+        ).validate()
+
+    def add_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        live: bool = True,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} ADD {'LIVE' if live else ''} RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def commit_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} COMMIT RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def drop_run(
+        self, *, experiment_name: sql_identifier.SqlIdentifier, run_name: sql_identifier.SqlIdentifier
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def modify_run(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        run_metadata: str,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session,
+            f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} SET METADATA=$${run_metadata}$$",
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def show_runs_in_experiment(
+        self, *, experiment_name: sql_identifier.SqlIdentifier, like: Optional[str] = None
+    ) -> list[row.Row]:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        like_clause = f"LIKE '{like}'" if like else ""
+        return query_result_checker.SqlResultValidator(
+            self._session, f"SHOW RUNS {like_clause} IN EXPERIMENT {experiment_fqn}"
+        ).validate()
snowflake/ml/experiment/_entities/run.py (new file)
@@ -0,0 +1,62 @@
+import json
+import types
+from typing import TYPE_CHECKING, Optional
+
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.experiment import _experiment_info as experiment_info
+from snowflake.ml.experiment._client import experiment_tracking_sql_client
+from snowflake.ml.experiment._entities import run_metadata
+
+if TYPE_CHECKING:
+    from snowflake.ml.experiment import experiment_tracking
+
+
+class Run:
+    def __init__(
+        self,
+        experiment_tracking: "experiment_tracking.ExperimentTracking",
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+    ) -> None:
+        self._experiment_tracking = experiment_tracking
+        self.experiment_name = experiment_name
+        self.name = run_name
+
+        self._patcher = experiment_info.ExperimentInfoPatcher(
+            experiment_info=self._get_experiment_info(),
+        )
+
+    def __enter__(self) -> "Run":
+        self._patcher.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
+        self._patcher.__exit__(exc_type, exc_value, traceback)
+        if self._experiment_tracking._run is self:
+            self._experiment_tracking.end_run()
+
+    def _get_metadata(
+        self,
+    ) -> run_metadata.RunMetadata:
+        runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
+            experiment_name=self.experiment_name, like=str(self.name)
+        )
+        if not runs:
+            raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
+        return run_metadata.RunMetadata.from_dict(
+            json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
+        )
+
+    def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
+        return experiment_info.ExperimentInfo(
+            fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
+                self._experiment_tracking._database_name, self._experiment_tracking._schema_name, self.experiment_name
+            ),
+            run_name=self.name.identifier(),
+        )
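
Run is essentially a context manager around a single tracked run: __enter__ installs an ExperimentInfoPatcher so registry log_model calls made inside the block are attributed to this run, and __exit__ removes the patch and ends the run if it is still the tracker's active run. The sketch below is illustrative only; exp, registry, and my_model are assumed to exist, and runs are normally created through ExperimentTracking (experiment_tracking.py, not shown in this hunk) rather than constructed directly.

# Sketch of the context-manager flow; not part of the diff.
from snowflake.ml._internal.utils import sql_identifier
from snowflake.ml.experiment._entities.run import Run

run = Run(
    exp,  # an ExperimentTracking instance
    experiment_name=sql_identifier.SqlIdentifier("MY_EXPERIMENT"),
    run_name=sql_identifier.SqlIdentifier("RUN_1"),
)
with run:
    # While the block is active, ModelManager.log_model is patched so the
    # logged model version carries this run's ExperimentInfo.
    registry.log_model(my_model, model_name="MY_MODEL", version_name="V1")
# On exit the patch is popped, and exp.end_run() is called if this run is
# still the tracker's active run.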
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import enum
|
|
3
|
+
import typing
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RunStatus(str, enum.Enum):
|
|
7
|
+
UNKNOWN = "UNKNOWN"
|
|
8
|
+
RUNNING = "RUNNING"
|
|
9
|
+
FINISHED = "FINISHED"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclasses.dataclass
|
|
13
|
+
class Metric:
|
|
14
|
+
name: str
|
|
15
|
+
value: float
|
|
16
|
+
step: int
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclasses.dataclass
|
|
20
|
+
class Param:
|
|
21
|
+
name: str
|
|
22
|
+
value: str
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclasses.dataclass
|
|
26
|
+
class RunMetadata:
|
|
27
|
+
status: RunStatus
|
|
28
|
+
metrics: list[Metric]
|
|
29
|
+
parameters: list[Param]
|
|
30
|
+
|
|
31
|
+
@classmethod
|
|
32
|
+
def from_dict(
|
|
33
|
+
cls,
|
|
34
|
+
metadata: dict, # type: ignore[type-arg]
|
|
35
|
+
) -> "RunMetadata":
|
|
36
|
+
return RunMetadata(
|
|
37
|
+
status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
|
|
38
|
+
metrics=[Metric(**m) for m in metadata.get("metrics", [])],
|
|
39
|
+
parameters=[Param(**p) for p in metadata.get("parameters", [])],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
def to_dict(self) -> dict: # type: ignore[type-arg]
|
|
43
|
+
return dataclasses.asdict(self)
|
|
44
|
+
|
|
45
|
+
def set_metric(
|
|
46
|
+
self,
|
|
47
|
+
key: str,
|
|
48
|
+
value: float,
|
|
49
|
+
step: int,
|
|
50
|
+
) -> None:
|
|
51
|
+
for metric in self.metrics:
|
|
52
|
+
if metric.name == key and metric.step == step:
|
|
53
|
+
metric.value = value
|
|
54
|
+
break
|
|
55
|
+
else:
|
|
56
|
+
self.metrics.append(Metric(name=key, value=value, step=step))
|
|
57
|
+
|
|
58
|
+
def set_param(
|
|
59
|
+
self,
|
|
60
|
+
key: str,
|
|
61
|
+
value: typing.Any,
|
|
62
|
+
) -> None:
|
|
63
|
+
for parameter in self.parameters:
|
|
64
|
+
if parameter.name == key:
|
|
65
|
+
parameter.value = str(value)
|
|
66
|
+
break
|
|
67
|
+
else:
|
|
68
|
+
self.parameters.append(Param(name=key, value=str(value)))
|
|
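
RunMetadata is the in-memory form of the JSON document stored in a run's metadata column: set_metric upserts by (name, step), set_param upserts by name and stringifies the value, and from_dict/to_dict handle the round-trip. A small self-contained example (illustrative, no Snowflake session required):

from snowflake.ml.experiment._entities.run_metadata import RunMetadata

meta = RunMetadata.from_dict({"status": "RUNNING", "metrics": [], "parameters": []})
meta.set_metric("loss", 0.42, step=0)
meta.set_metric("loss", 0.31, step=0)  # same (name, step): value is overwritten
meta.set_metric("loss", 0.27, step=1)  # new step: a new Metric entry is appended
meta.set_param("learning_rate", 0.1)   # stored as the string "0.1"
print(meta.to_dict())
# Roughly: {'status': <RunStatus.RUNNING: 'RUNNING'>,
#           'metrics': [{'name': 'loss', 'value': 0.31, 'step': 0},
#                       {'name': 'loss', 'value': 0.27, 'step': 1}],
#           'parameters': [{'name': 'learning_rate', 'value': '0.1'}]}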
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
import functools
|
|
3
|
+
import types
|
|
4
|
+
from typing import Callable, Optional
|
|
5
|
+
|
|
6
|
+
from snowflake.ml import model
|
|
7
|
+
from snowflake.ml.registry._manager import model_manager
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclasses.dataclass(frozen=True)
|
|
11
|
+
class ExperimentInfo:
|
|
12
|
+
"""Serializable information identifying a Experiment"""
|
|
13
|
+
|
|
14
|
+
fully_qualified_name: str
|
|
15
|
+
run_name: str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ExperimentInfoPatcher:
|
|
19
|
+
"""Context manager that patches ModelManager.log_model to include experiment information.
|
|
20
|
+
|
|
21
|
+
This class maintains a stack of active experiment contexts and ensures that
|
|
22
|
+
log_model calls are automatically tagged with the appropriate experiment info.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
# Store original method at class definition time to avoid recursive patching
|
|
26
|
+
_original_log_model: Callable[..., model.ModelVersion] = model_manager.ModelManager.log_model
|
|
27
|
+
|
|
28
|
+
# Stack of active experiment_info contexts for nested experiment support
|
|
29
|
+
_experiment_info_stack: list[ExperimentInfo] = []
|
|
30
|
+
|
|
31
|
+
def __init__(self, experiment_info: ExperimentInfo) -> None:
|
|
32
|
+
self._experiment_info = experiment_info
|
|
33
|
+
|
|
34
|
+
def __enter__(self) -> "ExperimentInfoPatcher":
|
|
35
|
+
# Only patch ModelManager.log_model if we're the first patcher to avoid nested patching
|
|
36
|
+
if not ExperimentInfoPatcher._experiment_info_stack:
|
|
37
|
+
|
|
38
|
+
@functools.wraps(ExperimentInfoPatcher._original_log_model)
|
|
39
|
+
def patched(*args, **kwargs) -> model.ModelVersion: # type: ignore[no-untyped-def]
|
|
40
|
+
# Use the most recent (top of stack) experiment_info for nested contexts
|
|
41
|
+
current_experiment_info = ExperimentInfoPatcher._experiment_info_stack[-1]
|
|
42
|
+
return ExperimentInfoPatcher._original_log_model(
|
|
43
|
+
*args, **kwargs, experiment_info=current_experiment_info
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
model_manager.ModelManager.log_model = patched # type: ignore[method-assign]
|
|
47
|
+
|
|
48
|
+
ExperimentInfoPatcher._experiment_info_stack.append(self._experiment_info)
|
|
49
|
+
return self
|
|
50
|
+
|
|
51
|
+
def __exit__(
|
|
52
|
+
self,
|
|
53
|
+
exc_type: Optional[type[BaseException]],
|
|
54
|
+
exc_value: Optional[BaseException],
|
|
55
|
+
traceback: Optional[types.TracebackType],
|
|
56
|
+
) -> None:
|
|
57
|
+
ExperimentInfoPatcher._experiment_info_stack.pop()
|
|
58
|
+
|
|
59
|
+
# Restore original method when no patches are active to clean up properly
|
|
60
|
+
if not ExperimentInfoPatcher._experiment_info_stack:
|
|
61
|
+
model_manager.ModelManager.log_model = ( # type: ignore[method-assign]
|
|
62
|
+
ExperimentInfoPatcher._original_log_model
|
|
63
|
+
)
|
|
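
ExperimentInfoPatcher follows a stack-based monkey-patching pattern: the original ModelManager.log_model is captured once at class-definition time, the wrapper is installed only when the first context enters, nested contexts simply push their ExperimentInfo, and the original method is restored when the last context exits. The same pattern in isolation, on a toy class with purely illustrative names, looks roughly like this:

import functools
from typing import Any, Callable

class Target:
    def work(self, x: int) -> int:
        return x

class TagPatcher:
    # Capture the original once so nested patchers never wrap the wrapper.
    _original: Callable[..., Any] = Target.work
    _stack: list[str] = []

    def __init__(self, tag: str) -> None:
        self._tag = tag

    def __enter__(self) -> "TagPatcher":
        if not TagPatcher._stack:  # install the patch only for the outermost context
            @functools.wraps(TagPatcher._original)
            def patched(self: Target, x: int) -> int:
                print("tag:", TagPatcher._stack[-1])  # innermost tag wins
                return TagPatcher._original(self, x)
            Target.work = patched
        TagPatcher._stack.append(self._tag)
        return self

    def __exit__(self, *exc: object) -> None:
        TagPatcher._stack.pop()
        if not TagPatcher._stack:  # last one out restores the original
            Target.work = TagPatcher._original

with TagPatcher("outer"), TagPatcher("inner"):
    Target().work(1)  # prints "tag: inner"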
@@ -0,0 +1,121 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
3
|
+
from warnings import warn
|
|
4
|
+
|
|
5
|
+
import lightgbm as lgb
|
|
6
|
+
import xgboost as xgb
|
|
7
|
+
|
|
8
|
+
from snowflake.ml.model.model_signature import ModelSignature
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from snowflake.ml.experiment.experiment_tracking import ExperimentTracking
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SnowflakeXgboostCallback(xgb.callback.TrainingCallback):
|
|
15
|
+
def __init__(
|
|
16
|
+
self,
|
|
17
|
+
experiment_tracking: "ExperimentTracking",
|
|
18
|
+
log_model: bool = True,
|
|
19
|
+
log_metrics: bool = True,
|
|
20
|
+
log_params: bool = True,
|
|
21
|
+
model_name: Optional[str] = None,
|
|
22
|
+
model_signature: Optional[ModelSignature] = None,
|
|
23
|
+
) -> None:
|
|
24
|
+
self._experiment_tracking = experiment_tracking
|
|
25
|
+
self.log_model = log_model
|
|
26
|
+
self.log_metrics = log_metrics
|
|
27
|
+
self.log_params = log_params
|
|
28
|
+
self.model_name = model_name
|
|
29
|
+
self.model_signature = model_signature
|
|
30
|
+
|
|
31
|
+
def before_training(self, model: xgb.Booster) -> xgb.Booster:
|
|
32
|
+
def _flatten_nested_params(params: Union[list[Any], dict[str, Any]], prefix: str = "") -> dict[str, Any]:
|
|
33
|
+
flat_params = {}
|
|
34
|
+
items = params.items() if isinstance(params, dict) else enumerate(params)
|
|
35
|
+
for key, value in items:
|
|
36
|
+
new_prefix = f"{prefix}.{key}" if prefix else str(key)
|
|
37
|
+
if isinstance(value, (dict, list)):
|
|
38
|
+
flat_params.update(_flatten_nested_params(value, new_prefix))
|
|
39
|
+
else:
|
|
40
|
+
flat_params[new_prefix] = value
|
|
41
|
+
return flat_params
|
|
42
|
+
|
|
43
|
+
if self.log_params:
|
|
44
|
+
params = json.loads(model.save_config())
|
|
45
|
+
self._experiment_tracking.log_params(_flatten_nested_params(params))
|
|
46
|
+
|
|
47
|
+
return model
|
|
48
|
+
|
|
49
|
+
def after_iteration(self, model: Any, epoch: int, evals_log: dict[str, dict[str, Any]]) -> bool:
|
|
50
|
+
if self.log_metrics:
|
|
51
|
+
for dataset_name, metrics in evals_log.items():
|
|
52
|
+
for metric_name, log in metrics.items():
|
|
53
|
+
metric_key = dataset_name + ":" + metric_name
|
|
54
|
+
self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=epoch)
|
|
55
|
+
|
|
56
|
+
return False
|
|
57
|
+
|
|
58
|
+
def after_training(self, model: xgb.Booster) -> xgb.Booster:
|
|
59
|
+
if self.log_model:
|
|
60
|
+
if not self.model_signature:
|
|
61
|
+
warn(
|
|
62
|
+
"Model will not be logged because model signature is missing. "
|
|
63
|
+
"To autolog the model, please specify `model_signature` when constructing SnowflakeXgboostCallback."
|
|
64
|
+
)
|
|
65
|
+
return model
|
|
66
|
+
|
|
67
|
+
model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
|
|
68
|
+
self._experiment_tracking.log_model( # type: ignore[call-arg]
|
|
69
|
+
model=model,
|
|
70
|
+
model_name=model_name,
|
|
71
|
+
signatures={"predict": self.model_signature},
|
|
72
|
+
)
|
|
73
|
+
|
|
74
|
+
return model
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class SnowflakeLightgbmCallback(lgb.callback._RecordEvaluationCallback):
|
|
78
|
+
def __init__(
|
|
79
|
+
self,
|
|
80
|
+
experiment_tracking: "ExperimentTracking",
|
|
81
|
+
log_model: bool = True,
|
|
82
|
+
log_metrics: bool = True,
|
|
83
|
+
log_params: bool = True,
|
|
84
|
+
model_name: Optional[str] = None,
|
|
85
|
+
model_signature: Optional[ModelSignature] = None,
|
|
86
|
+
) -> None:
|
|
87
|
+
self._experiment_tracking = experiment_tracking
|
|
88
|
+
self.log_model = log_model
|
|
89
|
+
self.log_metrics = log_metrics
|
|
90
|
+
self.log_params = log_params
|
|
91
|
+
self.model_name = model_name
|
|
92
|
+
self.model_signature = model_signature
|
|
93
|
+
|
|
94
|
+
super().__init__(eval_result={})
|
|
95
|
+
|
|
96
|
+
def __call__(self, env: lgb.callback.CallbackEnv) -> None:
|
|
97
|
+
if self.log_params:
|
|
98
|
+
if env.iteration == env.begin_iteration: # Log params only at the first iteration
|
|
99
|
+
self._experiment_tracking.log_params(env.params)
|
|
100
|
+
|
|
101
|
+
if self.log_metrics:
|
|
102
|
+
super().__call__(env)
|
|
103
|
+
for dataset_name, metrics in self.eval_result.items():
|
|
104
|
+
for metric_name, log in metrics.items():
|
|
105
|
+
metric_key = dataset_name + ":" + metric_name
|
|
106
|
+
self._experiment_tracking.log_metric(key=metric_key, value=log[-1], step=env.iteration)
|
|
107
|
+
|
|
108
|
+
if self.log_model:
|
|
109
|
+
if env.iteration == env.end_iteration - 1: # Log model only at the last iteration
|
|
110
|
+
if self.model_signature:
|
|
111
|
+
model_name = self.model_name or self._experiment_tracking._get_or_set_experiment().name + "_model"
|
|
112
|
+
self._experiment_tracking.log_model( # type: ignore[call-arg]
|
|
113
|
+
model=env.model,
|
|
114
|
+
model_name=model_name,
|
|
115
|
+
signatures={"predict": self.model_signature},
|
|
116
|
+
)
|
|
117
|
+
else:
|
|
118
|
+
warn(
|
|
119
|
+
"Model will not be logged because model signature is missing. To autolog the model, "
|
|
120
|
+
"please specify `model_signature` when constructing SnowflakeLightgbmCallback."
|
|
121
|
+
)
|
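
These callbacks hook framework-native training loops into experiment tracking: the XGBoost callback logs the flattened booster config as params before training, per-iteration eval metrics under dataset:metric keys, and (only if a signature is supplied) the trained booster as a model version afterwards; the LightGBM callback does the same from a single __call__ hook. An illustrative sketch of wiring the XGBoost callback into xgb.train follows; exp is assumed to be an ExperimentTracking instance from experiment_tracking.py, and X_train/y_train/X_val/y_val are placeholder data.

# Usage sketch; not part of the diff.
import xgboost as xgb

from snowflake.ml.experiment.callback import SnowflakeXgboostCallback
from snowflake.ml.model import model_signature

sig = model_signature.infer_signature(input_data=X_train, output_data=y_train)
cb = SnowflakeXgboostCallback(
    exp,                        # ExperimentTracking instance
    model_name="XGB_BASELINE",  # illustrative; defaults to "<experiment>_model"
    model_signature=sig,        # without this, only params and metrics are logged
)
booster = xgb.train(
    {"objective": "reg:squarederror", "max_depth": 4},
    xgb.DMatrix(X_train, label=y_train),
    num_boost_round=50,
    evals=[(xgb.DMatrix(X_val, label=y_val), "validation")],
    callbacks=[cb],
)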