nextrec-0.4.8-py3-none-any.whl → nextrec-0.4.9-py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +30 -15
- nextrec/basic/features.py +1 -0
- nextrec/basic/layers.py +6 -8
- nextrec/basic/loggers.py +14 -7
- nextrec/basic/metrics.py +6 -76
- nextrec/basic/model.py +312 -318
- nextrec/cli.py +5 -10
- nextrec/data/__init__.py +13 -16
- nextrec/data/batch_utils.py +3 -2
- nextrec/data/data_processing.py +10 -2
- nextrec/data/data_utils.py +9 -14
- nextrec/data/dataloader.py +12 -13
- nextrec/data/preprocessor.py +328 -255
- nextrec/loss/__init__.py +1 -5
- nextrec/loss/loss_utils.py +2 -8
- nextrec/models/generative/__init__.py +1 -8
- nextrec/models/generative/hstu.py +6 -4
- nextrec/models/multi_task/esmm.py +2 -2
- nextrec/models/multi_task/mmoe.py +2 -2
- nextrec/models/multi_task/ple.py +2 -2
- nextrec/models/multi_task/poso.py +2 -3
- nextrec/models/multi_task/share_bottom.py +2 -2
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +2 -2
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/ffm.py +0 -0
- nextrec/models/ranking/fibinet.py +5 -5
- nextrec/models/ranking/fm.py +3 -7
- nextrec/models/ranking/lr.py +0 -0
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +2 -2
- nextrec/models/ranking/widedeep.py +2 -2
- nextrec/models/ranking/xdeepfm.py +2 -2
- nextrec/models/representation/__init__.py +9 -0
- nextrec/models/{generative → representation}/rqvae.py +9 -9
- nextrec/models/retrieval/__init__.py +0 -0
- nextrec/models/{match → retrieval}/dssm.py +8 -3
- nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
- nextrec/models/{match → retrieval}/mind.py +4 -3
- nextrec/models/{match → retrieval}/sdm.py +4 -3
- nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
- nextrec/utils/__init__.py +60 -46
- nextrec/utils/config.py +8 -7
- nextrec/utils/console.py +371 -0
- nextrec/utils/{synthetic_data.py → data.py} +102 -15
- nextrec/utils/feature.py +15 -0
- nextrec/utils/torch_utils.py +411 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/METADATA +6 -6
- nextrec-0.4.9.dist-info/RECORD +70 -0
- nextrec/utils/cli_utils.py +0 -58
- nextrec/utils/device.py +0 -78
- nextrec/utils/distributed.py +0 -141
- nextrec/utils/file.py +0 -92
- nextrec/utils/initializer.py +0 -79
- nextrec/utils/optimizer.py +0 -75
- nextrec/utils/tensor.py +0 -72
- nextrec-0.4.8.dist-info/RECORD +0 -71
- /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
nextrec/utils/__init__.py
CHANGED
@@ -1,71 +1,84 @@
 """
 Utilities package for NextRec

-This package provides various utility functions organized by category:
-- optimizer: Optimizer and scheduler utilities
-- initializer: Weight initialization utilities
-- embedding: Embedding dimension calculation
-- device_utils: Device management and selection
-- tensor_utils: Tensor operations and conversions
-- file_utils: File I/O operations
-- model_utils: Model-related utilities
-- feature_utils: Feature processing utilities
-- config_utils: Configuration loading and processing utilities
-
 Date: create on 13/11/2025
-Last update:
+Last update: 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """

-from . import
-from .optimizer import get_optimizer, get_scheduler
-from .initializer import get_initializer
-from .embedding import get_auto_embedding_dim
-from .device import resolve_device, get_device_info
-from .tensor import to_tensor, stack_tensors, concat_tensors, pad_sequence_tensors
-from .file import (
-    resolve_file_paths,
-    read_table,
-    load_dataframes,
-    iter_file_chunks,
-    default_output_dir,
-    read_yaml,
-)
-from .model import merge_features, get_mlp_output_dim
-from .feature import normalize_to_list
-from .synthetic_data import (
-    generate_match_data,
-    generate_ranking_data,
-    generate_multitask_data,
-    generate_distributed_ranking_data,
-)
+from . import console, data, embedding, torch_utils
 from .config import (
-    resolve_path,
-    select_features,
-    register_processor_features,
     build_feature_objects,
+    build_model_instance,
     extract_feature_groups,
     load_model_class,
-
+    register_processor_features,
+    resolve_path,
+    select_features,
+)
+from .console import (
+    display_metrics_table,
+    get_nextrec_version,
+    log_startup_info,
+    progress,
+)
+from .data import (
+    default_output_dir,
+    generate_distributed_ranking_data,
+    generate_match_data,
+    generate_multitask_data,
+    generate_ranking_data,
+    iter_file_chunks,
+    load_dataframes,
+    read_table,
+    read_yaml,
+    resolve_file_paths,
+)
+from .embedding import get_auto_embedding_dim
+from .feature import normalize_to_list
+from .model import get_mlp_output_dim, merge_features
+from .torch_utils import (
+    add_distributed_sampler,
+    concat_tensors,
+    configure_device,
+    gather_numpy,
+    get_device_info,
+    get_initializer,
+    get_optimizer,
+    get_scheduler,
+    init_process_group,
+    pad_sequence_tensors,
+    resolve_device,
+    stack_tensors,
+    to_tensor,
 )

 __all__ = [
-    #
+    # Console utilities
+    "get_nextrec_version",
+    "log_startup_info",
+    "progress",
+    "display_metrics_table",
+    # Optimizer & Scheduler (torch utils)
     "get_optimizer",
     "get_scheduler",
-    # Initializer
+    # Initializer (torch utils)
     "get_initializer",
-    # Embedding
+    # Embedding utilities
     "get_auto_embedding_dim",
-    # Device utilities
+    # Device utilities (torch utils)
     "resolve_device",
     "get_device_info",
+    "configure_device",
+    "init_process_group",
+    "gather_numpy",
+    "add_distributed_sampler",
     # Tensor utilities
     "to_tensor",
     "stack_tensors",
     "concat_tensors",
     "pad_sequence_tensors",
-    #
+    # Data utilities
     "resolve_file_paths",
     "read_table",
     "read_yaml",
@@ -79,10 +92,10 @@ __all__ = [
     "normalize_to_list",
     # Config utilities
     "resolve_path",
-    "select_features",
     "register_processor_features",
     "build_feature_objects",
     "extract_feature_groups",
+    "select_features",
     "load_model_class",
     "build_model_instance",
     # Synthetic data utilities
@@ -91,7 +104,8 @@ __all__ = [
     "generate_multitask_data",
     "generate_distributed_ranking_data",
     # Module exports
-    "
-    "
+    "console",
+    "data",
     "embedding",
+    "torch_utils",
 ]
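
The practical effect of this reorganization is that the package-level names stay stable while the backing modules (`optimizer`, `initializer`, `device`, `tensor`, `file`, `synthetic_data`) are consolidated into `console`, `data`, and `torch_utils`. A minimal sketch of what the new layout implies for imports, assuming only the re-exports visible in the diff above (the comparison itself is illustrative):

# Package-level imports keep working in 0.4.9 because __init__.py re-exports them.
from nextrec.utils import get_optimizer, progress, to_tensor

# The same objects now live in the consolidated modules; deep imports of the
# removed modules (e.g. nextrec.utils.optimizer, nextrec.utils.tensor) no longer resolve.
from nextrec.utils.torch_utils import get_optimizer as get_optimizer_direct

assert get_optimizer is get_optimizer_direct

Code that imported the deleted modules directly would need to move to the new module paths or to the package-level names.
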
nextrec/utils/config.py
CHANGED
@@ -4,7 +4,8 @@ Configuration utilities for NextRec
 This module provides utilities for loading and processing configuration files,
 including feature configuration, model configuration, and training configuration.

-Date: create on
+Date: create on 27/10/2025
+Checkpoint: edit on 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """

@@ -23,7 +24,7 @@ import torch
 from nextrec.utils.feature import normalize_to_list

 if TYPE_CHECKING:
-    from nextrec.basic.features import DenseFeature,
+    from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
     from nextrec.data.preprocessor import DataProcessor


@@ -52,7 +53,7 @@ def select_features(
         names = [name for name in cfg.keys() if name in columns]
         missing = [name for name in cfg.keys() if name not in columns]
         if missing:
-            print(f"[
+            print(f"[Feature Config] skipped missing {group} columns: {missing}")
         return names

     dense_names = pick("dense")
@@ -129,7 +130,7 @@ def build_feature_objects(
         sparse_names: List of sparse feature names
         sequence_names: List of sequence feature names
     """
-    from nextrec.basic.features import DenseFeature,
+    from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature

     dense_cfg = feature_cfg.get("dense", {}) or {}
     sparse_cfg = feature_cfg.get("sparse", {}) or {}
@@ -236,7 +237,7 @@ def extract_feature_groups(

         if missing_defined:
             print(
-                f"[
+                f"[Feature Config] feature_groups.{group_name} contains features not defined in dense/sparse/sequence: {missing_defined}"
             )

         for n in name_list:
@@ -249,7 +250,7 @@ def extract_feature_groups(

         if missing_cols:
             print(
-                f"[
+                f"[Feature Config] feature_groups.{group_name} missing data columns: {missing_cols}"
             )

         resolved[group_name] = filtered
@@ -442,7 +443,7 @@ def build_model_instance(

         if group_key not in feature_groups:
             print(
-                f"[
+                f"[Feature Config] feature_bindings refers to unknown group '{group_key}', skipped"
             )
             continue

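
The `select_features` hunk above shows the pattern the config loader applies to every feature group: keep only the configured names that actually exist as data columns, and warn about the rest under the `[Feature Config]` prefix. A standalone sketch of that filtering step (the `pick_columns` name and signature are illustrative, not the library's API):

def pick_columns(cfg: dict, columns: list[str], group: str) -> list[str]:
    # Keep configured feature names that exist in the data; warn about the rest.
    names = [name for name in cfg.keys() if name in columns]
    missing = [name for name in cfg.keys() if name not in columns]
    if missing:
        print(f"[Feature Config] skipped missing {group} columns: {missing}")
    return names

# e.g. pick_columns({"age": {}, "income": {}}, columns=["age"], group="dense") -> ["age"]
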
nextrec/utils/console.py
ADDED
@@ -0,0 +1,371 @@
+"""
+Console and CLI utilities for NextRec.
+
+This module centralizes CLI logging helpers, progress display, and metric tables.
+
+Date: create on 19/12/2025
+Checkpoint: edit on 19/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
+"""
+
+from __future__ import annotations
+
+import io
+import logging
+import numbers
+import os
+import platform
+import sys
+from datetime import datetime, timedelta
+from typing import Any, Callable, Iterable, Mapping, TypeVar
+
+import numpy as np
+from rich import box
+from rich.console import Console
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TaskProgressColumn,
+    TextColumn,
+    TimeElapsedColumn,
+    TimeRemainingColumn,
+)
+from rich.table import Table
+from rich.text import Text
+
+from nextrec.utils.feature import as_float, normalize_to_list
+
+T = TypeVar("T")
+
+
+def get_nextrec_version() -> str:
+    """
+    Best-effort version resolver for NextRec.
+
+    Prefer in-repo `nextrec.__version__`, fall back to installed package metadata.
+    """
+    try:
+        from nextrec import __version__  # type: ignore
+
+        if __version__:
+            return str(__version__)
+    except Exception:
+        pass
+
+    try:
+        from importlib.metadata import version
+
+        return version("nextrec")
+    except Exception:
+        return "unknown"
+
+
+def log_startup_info(
+    logger: logging.Logger, *, mode: str, config_path: str | None
+) -> None:
+    """Log a short, user-friendly startup banner."""
+    version = get_nextrec_version()
+    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    lines = [
+        "NextRec CLI",
+        f"- Version: {version}",
+        f"- Time: {now}",
+        f"- Mode: {mode}",
+        f"- Config: {config_path or '(not set)'}",
+        f"- Python: {platform.python_version()} ({sys.executable})",
+        f"- Platform: {platform.system()} {platform.release()} ({platform.machine()})",
+        f"- Workdir: {os.getcwd()}",
+        f"- Command: {' '.join(sys.argv)}",
+    ]
+    for line in lines:
+        logger.info(line)
+
+
+class BlackTimeElapsedColumn(TimeElapsedColumn):
+    def render(self, task) -> Text:
+        elapsed = task.finished_time if task.finished else task.elapsed
+        if elapsed is None:
+            return Text("-:--:--", style="black")
+        delta = timedelta(seconds=max(0, int(elapsed)))
+        return Text(str(delta), style="black")
+
+
+class BlackTimeRemainingColumn(TimeRemainingColumn):
+    def render(self, task) -> Text:
+        if self.elapsed_when_finished and task.finished:
+            task_time = task.finished_time
+        else:
+            task_time = task.time_remaining
+
+        if task.total is None:
+            return Text("", style="black")
+
+        if task_time is None:
+            return Text("--:--" if self.compact else "-:--:--", style="black")
+
+        minutes, seconds = divmod(int(task_time), 60)
+        hours, minutes = divmod(minutes, 60)
+
+        if self.compact and not hours:
+            formatted = f"{minutes:02d}:{seconds:02d}"
+        else:
+            formatted = f"{hours:d}:{minutes:02d}:{seconds:02d}"
+
+        return Text(formatted, style="black")
+
+
+class BlackMofNCompleteColumn(MofNCompleteColumn):
+    def render(self, task) -> Text:
+        completed = int(task.completed)
+        total = int(task.total) if task.total is not None else "?"
+        total_width = len(str(total))
+        return Text(
+            f"{completed:{total_width}d}{self.separator}{total}",
+            style="black",
+        )
+
+
+def progress(
+    iterable: Iterable[T],
+    *,
+    description: str | None = None,
+    total: int | None = None,
+    disable: bool = False,
+) -> Iterable[T]:
+    if disable:
+        for item in iterable:
+            yield item
+        return
+    resolved_total = total
+    if resolved_total is None:
+        try:
+            resolved_total = len(iterable)  # type: ignore[arg-type]
+        except TypeError:
+            resolved_total = None
+
+    progress_bar = Progress(
+        SpinnerColumn(style="black"),
+        TextColumn("{task.description}", style="black"),
+        BarColumn(
+            bar_width=36, style="black", complete_style="black", finished_style="black"
+        ),
+        TaskProgressColumn(style="black"),
+        BlackMofNCompleteColumn(),
+        BlackTimeElapsedColumn(),
+        BlackTimeRemainingColumn(),
+        refresh_per_second=12,
+    )
+
+    task_id = progress_bar.add_task(description or "Working", total=resolved_total)
+    progress_bar.start()
+    try:
+        for item in iterable:
+            yield item
+            progress_bar.advance(task_id, 1)
+    finally:
+        progress_bar.stop()
+
+
+def group_metrics_by_task(
+    metrics: Mapping[str, Any] | None,
+    target_names: list[str] | str | None,
+    default_task_name: str = "overall",
+) -> tuple[list[str], dict[str, dict[str, float]]]:
+    if not metrics:
+        return [], {}
+
+    if isinstance(target_names, str):
+        target_names = [target_names]
+    if not isinstance(target_names, list) or not target_names:
+        target_names = [default_task_name]
+
+    targets_by_len = sorted(target_names, key=len, reverse=True)
+    grouped: dict[str, dict[str, float]] = {}
+    for key, raw_value in metrics.items():
+        value = as_float(raw_value)
+        if value is None:
+            continue
+
+        matched_target: str | None = None
+        metric_name = key
+        for target in targets_by_len:
+            suffix = f"_{target}"
+            if key.endswith(suffix):
+                metric_name = key[: -len(suffix)]
+                matched_target = target
+                break
+
+        if matched_target is None:
+            matched_target = (
+                target_names[0] if len(target_names) == 1 else default_task_name
+            )
+        grouped.setdefault(matched_target, {})[metric_name] = value
+
+    task_order: list[str] = []
+    for target in target_names:
+        if target in grouped:
+            task_order.append(target)
+    for task_name in grouped:
+        if task_name not in task_order:
+            task_order.append(task_name)
+    return task_order, grouped
+
+
+def display_metrics_table(
+    epoch: int,
+    epochs: int,
+    split: str,
+    loss: float | np.ndarray | None,
+    metrics: Mapping[str, Any] | None,
+    target_names: list[str] | str | None,
+    base_metrics: list[str] | None = None,
+    is_main_process: bool = True,
+    colorize: Callable[[str], str] | None = None,
+) -> None:
+    if not is_main_process:
+        return
+
+    target_list = normalize_to_list(target_names)
+    task_order, grouped = group_metrics_by_task(metrics, target_names=target_names)
+
+    if isinstance(loss, np.ndarray) and target_list:
+        # Ensure tasks with losses are shown even when metrics are missing for some targets.
+        normalized_order: list[str] = []
+        for name in target_list:
+            if name not in normalized_order:
+                normalized_order.append(name)
+        for name in task_order:
+            if name not in normalized_order:
+                normalized_order.append(name)
+        task_order = normalized_order
+
+    if Console is None or Table is None or box is None:
+        prefix = f"Epoch {epoch}/{epochs} - {split}:"
+        segments: list[str] = []
+        if isinstance(loss, numbers.Number):
+            segments.append(f"loss={float(loss):.4f}")
+        if task_order and grouped:
+            task_strs: list[str] = []
+            for task_name in task_order:
+                metric_items = grouped.get(task_name, {})
+                if not metric_items:
+                    continue
+                metric_str = ", ".join(
+                    f"{k}={float(v):.4f}" for k, v in metric_items.items()
+                )
+                task_strs.append(f"{task_name}[{metric_str}]")
+            if task_strs:
+                segments.append(", ".join(task_strs))
+        elif metrics:
+            metric_str = ", ".join(
+                f"{k}={float(v):.4f}"
+                for k, v in metrics.items()
+                if as_float(v) is not None
+            )
+            if metric_str:
+                segments.append(metric_str)
+        if not segments:
+            return
+        msg = f"{prefix} " + ", ".join(segments)
+        if colorize is not None:
+            msg = colorize(msg)
+        logging.info(msg)
+        return
+
+    title = f"Epoch {epoch}/{epochs} - {split}"
+    if isinstance(loss, numbers.Number):
+        title += f" (loss={float(loss):.4f})"
+
+    table = Table(
+        title=title,
+        box=box.ROUNDED,
+        header_style="bold",
+        title_style="bold",
+    )
+    table.add_column("Task", style="bold")
+
+    include_loss = isinstance(loss, np.ndarray)
+    if include_loss:
+        table.add_column("loss", justify="right")
+
+    metric_names: list[str] = []
+    for task_name in task_order:
+        for metric_name in grouped.get(task_name, {}):
+            if metric_name not in metric_names:
+                metric_names.append(metric_name)
+
+    preferred_order: list[str] = []
+    if isinstance(base_metrics, list):
+        preferred_order = [m for m in base_metrics if m in metric_names]
+    remaining = [m for m in metric_names if m not in preferred_order]
+    metric_names = preferred_order + sorted(remaining)
+
+    for metric_name in metric_names:
+        table.add_column(metric_name, justify="right")
+
+    def fmt(value: float | None) -> str:
+        if value is None:
+            return "-"
+        if np.isnan(value):
+            return "nan"
+        if np.isinf(value):
+            return "inf" if value > 0 else "-inf"
+        return f"{value:.4f}"
+
+    loss_by_task: dict[str, float] = {}
+    if isinstance(loss, np.ndarray):
+        if target_list:
+            for i, task_name in enumerate(target_list):
+                if i < loss.shape[0]:
+                    loss_by_task[task_name] = float(loss[i])
+            if "overall" in task_order and "overall" not in loss_by_task:
+                loss_by_task["overall"] = float(np.sum(loss))
+        elif task_order:
+            for i, task_name in enumerate(task_order):
+                if i < loss.shape[0]:
+                    loss_by_task[task_name] = float(loss[i])
+        else:
+            task_order = ["overall"]
+            loss_by_task["overall"] = float(np.sum(loss))
+
+    if not task_order:
+        task_order = ["__overall__"]
+
+    for task_name in task_order:
+        row: list[str] = [str(task_name)]
+        if include_loss:
+            row.append(fmt(loss_by_task.get(task_name)))
+        for metric_name in metric_names:
+            row.append(fmt(grouped.get(task_name, {}).get(metric_name)))
+        table.add_row(*row)
+
+    Console().print(table)
+
+    record_console = Console(file=io.StringIO(), record=True, width=120)
+    record_console.print(table)
+    table_text = record_console.export_text(styles=False).rstrip()
+
+    root_logger = logging.getLogger()
+    record = root_logger.makeRecord(
+        root_logger.name,
+        logging.INFO,
+        __file__,
+        0,
+        "[MetricsTable]\n" + table_text,
+        args=(),
+        exc_info=None,
+        extra=None,
+    )
+
+    emitted = False
+    for handler in root_logger.handlers:
+        if isinstance(handler, logging.FileHandler):
+            handler.emit(record)
+            emitted = True
+
+    if not emitted:
+        # Fallback: no file handlers configured, use standard logging.
+        root_logger.log(logging.INFO, "[MetricsTable]\n" + table_text)
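
Taken together, `console.py` gives the CLI one place for startup logging, progress display, and per-task metric tables. A short usage sketch based only on the signatures added above (the logger name, config path, metric keys, and target names are illustrative):

import logging

import numpy as np

from nextrec.utils.console import display_metrics_table, log_startup_info, progress

logger = logging.getLogger("nextrec")
log_startup_info(logger, mode="train", config_path="config.yaml")

# progress() wraps any iterable in a rich progress bar (pass disable=True to turn it off).
for step in progress(range(100), description="Epoch 1"):
    pass  # training step goes here

# With an ndarray loss and explicit target names, the table gets one row per task, and
# metric keys ending in "_<target>" (e.g. "auc_ctr") are grouped under that task.
display_metrics_table(
    epoch=1,
    epochs=10,
    split="val",
    loss=np.array([0.31, 0.12]),
    metrics={"auc_ctr": 0.74, "auc_cvr": 0.69},
    target_names=["ctr", "cvr"],
)

Because `display_metrics_table` also re-renders the table into any configured `logging.FileHandler` (falling back to the root logger when none is attached), the table printed to the console ends up in the run's log file as well.
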