nextrec 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/utils/file.py ADDED
@@ -0,0 +1,70 @@
1
+ """
2
+ File I/O utilities for NextRec
3
+
4
+ Date: create on 03/12/2025
5
+ Author: Yang Zhou, zyaztec@gmail.com
6
+ """
7
+
8
+ import pandas as pd
9
+ import pyarrow.parquet as pq
10
+ from pathlib import Path
11
+ from typing import Generator
12
+
13
+
14
def resolve_file_paths(path: str) -> tuple[list[str], str]:
    """
    Resolve a file or directory path into a sorted list of data files and their format.

    Args:
        path: Path to a single ``.csv``/``.parquet`` file (extension matched
            case-insensitively), or to a directory whose top-level files are
            all CSV or all Parquet. Other files in the directory are ignored
            and subdirectories are not searched.

    Returns:
        tuple: ``(file_paths, file_type)``. ``file_paths`` is a sorted list of
        path strings and ``file_type`` is ``"csv"`` or ``"parquet"``.

    Raises:
        ValueError: if a single file has an unsupported extension, if the
            directory mixes CSV and Parquet files or contains neither, or if
            *path* is neither an existing file nor an existing directory.
    """
    path_obj = Path(path)

    if path_obj.is_file():
        file_type = path_obj.suffix.lower().lstrip(".")
        # Raise instead of ``assert``. Assertions are stripped under ``python -O``,
        # and every other validation failure in this function is a ValueError.
        if file_type not in ("csv", "parquet"):
            raise ValueError(f"Unsupported file extension: {file_type}")
        return [str(path_obj)], file_type

    if path_obj.is_dir():
        collected_files = [p for p in path_obj.iterdir() if p.is_file()]
        csv_files = [str(p) for p in collected_files if p.suffix.lower() == ".csv"]
        parquet_files = [str(p) for p in collected_files if p.suffix.lower() == ".parquet"]

        if csv_files and parquet_files:
            raise ValueError("Directory contains both CSV and Parquet files. Please keep a single format.")
        if not (csv_files or parquet_files):
            raise ValueError(f"No CSV or Parquet files found in directory: {path}")
        file_type = "csv" if csv_files else "parquet"
        # Sort so that downstream iteration order is deterministic.
        file_paths = sorted(csv_files or parquet_files)
        return file_paths, file_type

    raise ValueError(f"Invalid path: {path}")
43
+
44
+
45
def read_table(file_path: str, file_type: str) -> pd.DataFrame:
    """
    Load a single data file into a DataFrame.

    Args:
        file_path: Path of the file to read.
        file_type: ``"csv"`` selects ``pandas.read_csv``. Every other value
            falls through to ``pandas.read_parquet``.

    Returns:
        pd.DataFrame: the full contents of the file.
    """
    reader = pd.read_csv if file_type == "csv" else pd.read_parquet
    return reader(file_path)
49
+
50
def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
    """
    Read every file in *file_paths* with ``read_table``.

    Args:
        file_paths: Paths to read. The output keeps this order.
        file_type: Format passed through unchanged to ``read_table``.

    Returns:
        list[pd.DataFrame]: one DataFrame per input path.
    """
    frames: list[pd.DataFrame] = []
    for single_path in file_paths:
        frames.append(read_table(single_path, file_type))
    return frames
52
+
53
def iter_file_chunks(
    file_path: str,
    file_type: str,
    chunk_size: int
) -> Generator[pd.DataFrame, None, None]:
    """
    Lazily read a CSV or Parquet file in chunks of at most *chunk_size* rows.

    Args:
        file_path: Path to the file.
        file_type: ``"csv"`` uses pandas' chunked CSV reader. Any other value
            is read as Parquet via ``pyarrow.parquet.ParquetFile.iter_batches``.
        chunk_size: Maximum number of rows per yielded DataFrame.

    Yields:
        pd.DataFrame: consecutive row chunks of the file.
    """
    if file_type == "csv":
        # Use the reader as a context manager so its file handle is closed even
        # when the consumer stops early (generator ``close()`` or garbage collection).
        with pd.read_csv(file_path, chunksize=chunk_size) as reader:
            yield from reader
        return
    parquet_file = pq.ParquetFile(file_path)
    for batch in parquet_file.iter_batches(batch_size=chunk_size):
        yield batch.to_pandas()
64
+
65
+
66
def default_output_dir(path: str) -> Path:
    """
    Derive the default output directory for preprocessed data.

    For an existing file ``<dir>/<stem>.<ext>`` the result is
    ``<dir>/<stem>_preprocessed``. For anything else (a directory, or a path
    that does not exist) the result is a sibling named ``<name>_preprocessed``.

    Raises:
        ValueError: if *path* has no final component even after resolving
            (e.g. the filesystem root).
    """
    path_obj = Path(path)
    if path_obj.is_file():
        return path_obj.parent / f"{path_obj.stem}_preprocessed"
    if not path_obj.name:
        # ``Path(".")`` has an empty name, which makes ``with_name`` raise.
        # Resolve it so the real directory name is used instead.
        path_obj = path_obj.resolve()
    return path_obj.with_name(f"{path_obj.name}_preprocessed")
@@ -9,14 +9,6 @@ import torch.nn as nn
9
9
 
10
10
 
11
11
  def get_initializer(init_type='normal', activation='linear', param=None):
12
- """
13
- Get parameter initialization function.
14
-
15
- Examples:
16
- >>> init_fn = get_initializer('xavier_uniform', 'relu')
17
- >>> init_fn(tensor)
18
- >>> init_fn = get_initializer('normal', param={'mean': 0.0, 'std': 0.01})
19
- """
20
12
  param = param or {}
21
13
 
22
14
  try:
nextrec/utils/model.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ Model-related utilities for NextRec
3
+
4
+ Date: create on 03/12/2025
5
+ Author: Yang Zhou, zyaztec@gmail.com
6
+ """
7
+
8
+ from collections import OrderedDict
9
+
10
+
11
def merge_features(primary, secondary) -> list:
    """
    Combine two feature collections, de-duplicating by ``feature.name``.

    Features keep their first-appearance order across *primary* and then
    *secondary*. When a name repeats, the earliest object wins. ``None`` or
    empty inputs contribute nothing.
    """
    names_seen: set = set()
    unique: list = []
    for feature in list(primary or []) + list(secondary or []):
        key = feature.name
        if key in names_seen:
            continue
        names_seen.add(key)
        unique.append(feature)
    return unique
16
+
17
+
18
def get_mlp_output_dim(params: dict, fallback: int) -> int:
    """
    Return the output width of an MLP-style config.

    Uses the last entry of ``params["dims"]`` when that value is present and
    truthy. Otherwise returns *fallback*.
    """
    layer_dims = params.get("dims")
    return layer_dims[-1] if layer_dims else fallback
@@ -8,25 +8,16 @@ Author: Yang Zhou, zyaztec@gmail.com
8
8
  import torch
9
9
  from typing import Iterable
10
10
 
11
-
12
11
  def get_optimizer(
13
12
  optimizer: str | torch.optim.Optimizer = "adam",
14
13
  params: Iterable[torch.nn.Parameter] | None = None,
15
14
  **optimizer_params
16
15
  ):
17
- """
18
- Get optimizer function based on optimizer name or instance.
19
-
20
- Examples:
21
- >>> optimizer = get_optimizer("adam", model.parameters(), lr=1e-3)
22
- >>> optimizer = get_optimizer("sgd", model.parameters(), lr=0.01, momentum=0.9)
23
- """
24
16
  if params is None:
25
17
  raise ValueError("params cannot be None. Please provide model parameters.")
26
18
 
27
19
  if 'lr' not in optimizer_params:
28
20
  optimizer_params['lr'] = 1e-3
29
-
30
21
  if isinstance(optimizer, str):
31
22
  opt_name = optimizer.lower()
32
23
  if opt_name == "adam":
@@ -42,27 +33,17 @@ def get_optimizer(
42
33
  else:
43
34
  raise NotImplementedError(f"Unsupported optimizer: {optimizer}")
44
35
  optimizer_fn = opt_class(params=params, **optimizer_params)
45
-
46
36
  elif isinstance(optimizer, torch.optim.Optimizer):
47
37
  optimizer_fn = optimizer
48
38
  else:
49
39
  raise TypeError(f"Invalid optimizer type: {type(optimizer)}")
50
-
51
40
  return optimizer_fn
52
41
 
53
-
54
42
  def get_scheduler(
55
43
  scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None,
56
44
  optimizer,
57
45
  **scheduler_params
58
46
  ):
59
- """
60
- Get learning rate scheduler function.
61
-
62
- Examples:
63
- >>> scheduler = get_scheduler("step", optimizer, step_size=10, gamma=0.1)
64
- >>> scheduler = get_scheduler("cosine", optimizer, T_max=100)
65
- """
66
47
  if isinstance(scheduler, str):
67
48
  if scheduler == "step":
68
49
  scheduler_fn = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_params)
@@ -0,0 +1,61 @@
1
+ """
2
+ Tensor manipulation utilities for NextRec
3
+
4
+ Date: create on 03/12/2025
5
+ Author: Yang Zhou, zyaztec@gmail.com
6
+ """
7
+
8
+ import torch
9
+ import numpy as np
10
+ from typing import Any
11
+
12
+
13
+ def to_tensor(
14
+ value: Any,
15
+ dtype: torch.dtype,
16
+ device: torch.device | str | None = None
17
+ ) -> torch.Tensor:
18
+ if value is None:
19
+ raise ValueError("[Tensor Utils Error] Cannot convert None to tensor.")
20
+ tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
21
+ if tensor.dtype != dtype:
22
+ tensor = tensor.to(dtype=dtype)
23
+
24
+ if device is not None:
25
+ target_device = device if isinstance(device, torch.device) else torch.device(device)
26
+ if tensor.device != target_device:
27
+ tensor = tensor.to(target_device)
28
+ return tensor
29
+
30
def stack_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Stack same-shaped tensors along a new axis *dim* via ``torch.stack``.

    Raises:
        ValueError: if *tensors* is empty.
    """
    if tensors:
        return torch.stack(tensors, dim=dim)
    raise ValueError("[Tensor Utils Error] Cannot stack empty list of tensors.")
34
+
35
def concat_tensors(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    """
    Concatenate tensors along the existing axis *dim* via ``torch.cat``.

    Raises:
        ValueError: if *tensors* is empty.
    """
    if tensors:
        return torch.cat(tensors, dim=dim)
    raise ValueError("[Tensor Utils Error] Cannot concatenate empty list of tensors.")
39
+
40
+ def pad_sequence_tensors(
41
+ tensors: list[torch.Tensor],
42
+ max_len: int | None = None,
43
+ padding_value: float = 0.0,
44
+ padding_side: str = 'right'
45
+ ) -> torch.Tensor:
46
+ if not tensors:
47
+ raise ValueError("[Tensor Utils Error] Cannot pad empty list of tensors.")
48
+ if max_len is None:
49
+ max_len = max(t.size(0) for t in tensors)
50
+ batch_size = len(tensors)
51
+ padded = torch.full((batch_size, max_len), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
52
+
53
+ for i, tensor in enumerate(tensors):
54
+ length = min(tensor.size(0), max_len)
55
+ if padding_side == 'right':
56
+ padded[i, :length] = tensor[:length]
57
+ elif padding_side == 'left':
58
+ padded[i, -length:] = tensor[:length]
59
+ else:
60
+ raise ValueError(f"[Tensor Utils Error] padding_side must be 'right' or 'left', got {padding_side}")
61
+ return padded
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.3.4
3
+ Version: 0.3.6
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
63
63
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
64
64
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
65
65
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
66
- ![Version](https://img.shields.io/badge/Version-0.3.4-orange.svg)
66
+ ![Version](https://img.shields.io/badge/Version-0.3.6-orange.svg)
67
67
 
68
68
  English | [中文文档](README_zh.md)
69
69
 
@@ -110,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
110
110
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
111
111
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
112
112
 
113
- > Current version [0.3.4]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
113
+ > Current version [0.3.6]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
114
114
 
115
115
  ## 5-Minute Quick Start
116
116
 
@@ -1,38 +1,41 @@
1
- nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
2
- nextrec/__version__.py,sha256=oYLGMpySamd16KLiaBTfRyrAS7_oyp-TOEHmzmeumwg,22
1
+ nextrec/__init__.py,sha256=nFRpUAjezaxyMJDTgy4g9PtpDTq28sMHleSrlg3QkVA,235
2
+ nextrec/__version__.py,sha256=W_9dCm49nLvZulVAvvsafxLJjVBSKDBHz9K7szFZllo,22
3
3
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  nextrec/basic/activation.py,sha256=1qs9pq4hT3BUxIiYdYs57axMCm4-JyOBFQ6x7xkHTwM,2849
5
5
  nextrec/basic/callback.py,sha256=wwh0I2kKYyywCB-sG9eQXShlpXFJIo75qApJmnI5p6c,1036
6
- nextrec/basic/features.py,sha256=-RRRbEPU-SFI-GtppflW6O0bKShUsV-Hg_lTGpo3AIE,4262
6
+ nextrec/basic/features.py,sha256=DFwYjG13GYHOujS_CMKa7Qrux9faF7MQNoaoRDF_Eks,4263
7
7
  nextrec/basic/layers.py,sha256=zzEseKYVnMVs1Tg5EGrFimugId15jI6HumgzjFyRqgw,23127
8
- nextrec/basic/loggers.py,sha256=hh9tRMmaCTaJ_sfRHIlbcqd6BcpK63vpZ_21TFCiKLI,6148
8
+ nextrec/basic/loggers.py,sha256=YLmeXsnzm9M2qxtmBOLMGZRg9wOAUQYl8UNpbWFzs8s,6147
9
9
  nextrec/basic/metrics.py,sha256=8-hMZJXU5L4F8GnToxMZey5dlBrtFyRtTuI_zoQCtIo,21579
10
- nextrec/basic/model.py,sha256=afnvicyxXMgWdvhrIUaoNnZ7S-QYRYr7fTY5bdM1u_s,68829
11
- nextrec/basic/session.py,sha256=oaATn-nzbJ9A6SGbMut9xLV_NSh9_1KmVDeNauS06Ps,4767
12
- nextrec/data/__init__.py,sha256=6WgXZafzzXcv5kuxKNi67O8BJZVl_P_HM2IZCDIIhPA,1052
13
- nextrec/data/data_utils.py,sha256=aOyja3Yu7O2c8eIeL3P8MyUlUR5EerOUT9UeF4ATq8o,10574
14
- nextrec/data/dataloader.py,sha256=2MLe69y0E1cTZyzMNgyLUCxa6lllGd1ntvwpXzxdX10,14199
15
- nextrec/data/preprocessor.py,sha256=lhigpjvkEqsjTRfbBBOjgGOxoPyOifwq2LoswgyIVqc,40488
10
+ nextrec/basic/model.py,sha256=LybJlpzK2S6zw8ez_HrR_tFc15Gzcy0t4GMD12i9sA0,69310
11
+ nextrec/basic/session.py,sha256=kYpUE6KzN2_Jli4l-YuoeMBaghGi3kzDnGRP3E08FbQ,4430
12
+ nextrec/data/__init__.py,sha256=OJsuESaE0NZorAkAwydWJtsWsbNBzKfmQCrDJTzA5a0,1227
13
+ nextrec/data/batch_utils.py,sha256=6G-E85H-PqYJ20EYVLnC3MqC8xYrXzZ1XYe82MhRPck,2816
14
+ nextrec/data/data_processing.py,sha256=N3Uk4NsUCyLeoMDV1zeLmH-dP02I-cRWDo-vvQgLqjo,5006
15
+ nextrec/data/data_utils.py,sha256=-3xLPW3csOiGNmj0kzzpOkCxZyu09RNBgfPkwX7nDAc,1172
16
+ nextrec/data/dataloader.py,sha256=JOudvhnMcNBYFlSKbMKi43Ndn2c1kGoyD8G9gTlW0Ps,14699
17
+ nextrec/data/preprocessor.py,sha256=_A3eEc1MpUGDEpno1TToA-dyJ_k707Mr3GilTi_9j5I,40419
16
18
  nextrec/loss/__init__.py,sha256=mO5t417BneZ8Ysa51GyjDaffjWyjzFgPXIQrrggasaQ,827
17
19
  nextrec/loss/listwise.py,sha256=gxDbO1td5IeS28jKzdE35o1KAYBRdCYoMzyZzfNLhc0,5689
18
20
  nextrec/loss/loss_utils.py,sha256=uZ4m9ChLr-UgIc5Yxm1LjwXDDepApQ-Fas8njweZ9qg,2641
19
21
  nextrec/loss/pairwise.py,sha256=MN_3Pk6Nj8KCkmUqGT5cmyx1_nQa3TIx_kxXT_HB58c,3396
20
22
  nextrec/loss/pointwise.py,sha256=shgdRJwTV7vAnVxHSffOJU4TPQeKyrwudQ8y-R10nYM,7144
21
- nextrec/models/generative/__init__.py,sha256=vo8-DloD74cKc1moSH-4GYG99w8Yi8YPGPxh8XDJPoc,50
23
+ nextrec/models/generative/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
24
  nextrec/models/generative/hstu.py,sha256=CLu8Ee_L4fdnb7_DKocz0g7SZlPI1g_6o8HtyzRkI9s,16368
23
25
  nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
- nextrec/models/match/__init__.py,sha256=ASZB5abqKPhDbk8NErNNNa0DHuWpsVxvUtyEn5XMx6Y,215
26
+ nextrec/models/match/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
27
  nextrec/models/match/dssm.py,sha256=1cj1Fb3yFTKxA1kRaomh_Q8y66vBZc85ywAIXWosyL0,8230
26
28
  nextrec/models/match/dssm_v2.py,sha256=BY6m9651NlzMLjRa9oeez0dab_3NjNFVgYQ7Q39Ug74,7187
27
29
  nextrec/models/match/mind.py,sha256=0cggXKE1-XsTZ6IX4UH81a5KycdGF-_ix2Nw-eKTLpg,14886
28
30
  nextrec/models/match/sdm.py,sha256=wVRj6PWLF6hMIqqlJDUuKqxJAvCGPe-HfD3EVgd16Sw,10918
29
31
  nextrec/models/match/youtube_dnn.py,sha256=Wa5JWrlIpMuBoyXpnBrdnm1nQ8ZO_XcR517zfINh-xA,7544
32
+ nextrec/models/multi_task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
33
  nextrec/models/multi_task/esmm.py,sha256=Ho5UN2H9H9-ZYML6eqpBYTVdTO4Ja9AoYP5SSgsgQaw,6442
31
34
  nextrec/models/multi_task/mmoe.py,sha256=zfBAUoQijHCuat962dZI0MCAy8C6PZqZ-zOd16JznF8,7803
32
35
  nextrec/models/multi_task/ple.py,sha256=zNBea0sfJska36RVH1N9O92m7rPmbaWYqoPbnGoy1RE,11949
33
- nextrec/models/multi_task/poso.py,sha256=_yLiCkD3NhOZEOWx-jP4MJxSEdNCu3mqeo_XRt8CWts,16652
36
+ nextrec/models/multi_task/poso.py,sha256=_Pq-cl7HB1uQVO8HXreNeVpQso250ouxBNTsdTjyFos,16651
34
37
  nextrec/models/multi_task/share_bottom.py,sha256=kvrkXQSTDPEwwmBvXw3xryBm3gT8Uq4_Hb3TenwRj9w,5920
35
- nextrec/models/ranking/__init__.py,sha256=AY806x-2BtltQdlR4wu23-keL9YUe3An92OJshS4t9Y,472
38
+ nextrec/models/ranking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
39
  nextrec/models/ranking/afm.py,sha256=uFSUIv9d6NQkCiM2epmSdMy4kxjFuCRVbrZOv3nebGE,4539
37
40
  nextrec/models/ranking/autoint.py,sha256=MN6Dv6EMK0ODsCEeX4iXBSDoxK9a_DxIdEduGAUIVEQ,7771
38
41
  nextrec/models/ranking/dcn.py,sha256=sy0v_kMQ1OfVCFuRD8FDrroQHm-RnTx4lVacfXfs2N8,4896
@@ -46,12 +49,16 @@ nextrec/models/ranking/masknet.py,sha256=9K6XKcr8f8PcVhLfgFd8l4tq78lcclAQAXZKlVE
46
49
  nextrec/models/ranking/pnn.py,sha256=eEyBnALuzaNx27iGJ0ZqNcf0u7dKN8SiO03lkcv1hiE,4956
47
50
  nextrec/models/ranking/widedeep.py,sha256=AJPkoThUTSBGPNBjD-aiWsMH2hSiSnGLjIPy_2neNhc,5034
48
51
  nextrec/models/ranking/xdeepfm.py,sha256=wn6YnX78EyBzil7IRBcqyDqsnysERVJ5-lWGuRMCpxE,5681
49
- nextrec/utils/__init__.py,sha256=ciw6B9SXffjSb4cwco-WXpKSE7M9D6ILpLZ2oftwj6A,457
50
- nextrec/utils/common.py,sha256=NYXnBVtUCtm8epT2ZxJHn_m1SIBBI_PEjZ5VpL465ls,2009
52
+ nextrec/utils/__init__.py,sha256=lAVpHsGe_WgGf7R-K1wr0DeVLvskG0Bj1L12N6kEPwM,1810
53
+ nextrec/utils/device.py,sha256=1QtmlpxRSHiuYfmCOQCOIk-s6bmjIpoJtzfXOcvginI,1044
51
54
  nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
52
- nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
53
- nextrec/utils/optimizer.py,sha256=EUjAGFPeyou_Cv-_2HRvjzut8y_qpAQudc8L2T0k8zw,2706
54
- nextrec-0.3.4.dist-info/METADATA,sha256=X5fo5gymQdPXLgM1N03E58uFSQyuQOmdbUp8vXvKl0g,16319
55
- nextrec-0.3.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
56
- nextrec-0.3.4.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
57
- nextrec-0.3.4.dist-info/RECORD,,
55
+ nextrec/utils/feature.py,sha256=s0eMEuvbOsotjll7eSYjb0b-1cXnvVy1mSI1Syg_7n4,299
56
+ nextrec/utils/file.py,sha256=wxKvd1_U9ugFDP7EzLNG6-3PBInA0QhxoHzBWKfe_B8,2384
57
+ nextrec/utils/initializer.py,sha256=BkP6-vJdsc0A-8ya-AVEs7W24dPXyxIilNnckwXgPEc,1391
58
+ nextrec/utils/model.py,sha256=FB7QbatO0uEvghBEfByJtRS0waaBEB1UI0YzfA_2k04,535
59
+ nextrec/utils/optimizer.py,sha256=cVkDrEkxwig17UAEhL8p9v3iVNiXI8B067Yf_6LqUp8,2198
60
+ nextrec/utils/tensor.py,sha256=_RibR6BMPizhzRLVdnJqwUgzA0zpzkZuKfTrdSjbL60,2136
61
+ nextrec-0.3.6.dist-info/METADATA,sha256=yq_cvYiBZzWJcZaIlBornYCW_Hc8v7p01mhRuN15jOk,16319
62
+ nextrec-0.3.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
63
+ nextrec-0.3.6.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
64
+ nextrec-0.3.6.dist-info/RECORD,,
nextrec/utils/common.py DELETED
@@ -1,60 +0,0 @@
1
- import torch
2
- import platform
3
- from collections import OrderedDict
4
-
5
-
6
def resolve_device() -> str:
    """
    Select a usable device with graceful fallback.

    Preference order is ``"cuda"``, then ``"mps"`` (only when the reported
    macOS major version is 14 or newer), then ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` is missing from older torch builds. Guard the lookup
    # so those versions fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Only the major version is compared. Parsing just it also accepts
        # version strings without a minor part, such as "14".
        try:
            major = int(platform.mac_ver()[0].split(".")[0])
        except ValueError:
            major = 0
        if major >= 14:
            return "mps"
    return "cpu"
19
-
20
-
21
- def normalize_to_list(value: str | list[str] | None) -> list[str]:
22
- if value is None:
23
- return []
24
- if isinstance(value, str):
25
- return [value]
26
- return list(value)
27
-
28
-
29
def merge_features(primary, secondary) -> list:
    """
    Merge two feature lists, keeping first-appearance order and de-duplicating by name.

    When a name appears more than once, only its first occurrence is kept.
    ``None`` inputs are treated as empty.
    """
    by_name: dict = {}
    for feat in [*(primary or []), *(secondary or [])]:
        by_name.setdefault(feat.name, feat)
    return list(by_name.values())
38
-
39
def get_mlp_output_dim(params: dict, fallback: int) -> int:
    """
    Get the output dimension of an MLP-like config.

    Returns the last element of ``params["dims"]`` when that value is present
    and truthy. Otherwise returns *fallback*.
    """
    configured = params.get("dims")
    if not configured:
        return fallback
    return configured[-1]
48
-
49
- def to_tensor(value, dtype: torch.dtype, device: torch.device | str | None = None) -> torch.Tensor:
50
- """Convert any value to a tensor with the desired dtype/device."""
51
- if value is None:
52
- raise ValueError("[Tensor Utils Error] Cannot convert None to tensor.")
53
- tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
54
- if tensor.dtype != dtype:
55
- tensor = tensor.to(dtype=dtype)
56
- if device is not None:
57
- target_device = device if isinstance(device, torch.device) else torch.device(device)
58
- if tensor.device != target_device:
59
- tensor = tensor.to(target_device)
60
- return tensor