recnexteval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110) hide show
  1. recnexteval/__init__.py +20 -0
  2. recnexteval/algorithms/__init__.py +99 -0
  3. recnexteval/algorithms/base.py +377 -0
  4. recnexteval/algorithms/baseline/__init__.py +10 -0
  5. recnexteval/algorithms/baseline/decay_popularity.py +110 -0
  6. recnexteval/algorithms/baseline/most_popular.py +72 -0
  7. recnexteval/algorithms/baseline/random.py +39 -0
  8. recnexteval/algorithms/baseline/recent_popularity.py +34 -0
  9. recnexteval/algorithms/itemknn/__init__.py +14 -0
  10. recnexteval/algorithms/itemknn/itemknn.py +119 -0
  11. recnexteval/algorithms/itemknn/itemknn_incremental.py +65 -0
  12. recnexteval/algorithms/itemknn/itemknn_incremental_movielens.py +95 -0
  13. recnexteval/algorithms/itemknn/itemknn_rolling.py +17 -0
  14. recnexteval/algorithms/itemknn/itemknn_static.py +31 -0
  15. recnexteval/algorithms/time_aware_item_knn/__init__.py +11 -0
  16. recnexteval/algorithms/time_aware_item_knn/base.py +248 -0
  17. recnexteval/algorithms/time_aware_item_knn/decay_functions.py +260 -0
  18. recnexteval/algorithms/time_aware_item_knn/ding_2005.py +52 -0
  19. recnexteval/algorithms/time_aware_item_knn/liu_2010.py +65 -0
  20. recnexteval/algorithms/time_aware_item_knn/similarity_functions.py +106 -0
  21. recnexteval/algorithms/time_aware_item_knn/top_k.py +61 -0
  22. recnexteval/algorithms/time_aware_item_knn/utils.py +47 -0
  23. recnexteval/algorithms/time_aware_item_knn/vaz_2013.py +50 -0
  24. recnexteval/algorithms/utils.py +51 -0
  25. recnexteval/datasets/__init__.py +109 -0
  26. recnexteval/datasets/base.py +316 -0
  27. recnexteval/datasets/config/__init__.py +113 -0
  28. recnexteval/datasets/config/amazon.py +188 -0
  29. recnexteval/datasets/config/base.py +72 -0
  30. recnexteval/datasets/config/lastfm.py +105 -0
  31. recnexteval/datasets/config/movielens.py +169 -0
  32. recnexteval/datasets/config/yelp.py +25 -0
  33. recnexteval/datasets/datasets/__init__.py +24 -0
  34. recnexteval/datasets/datasets/amazon.py +151 -0
  35. recnexteval/datasets/datasets/base.py +250 -0
  36. recnexteval/datasets/datasets/lastfm.py +121 -0
  37. recnexteval/datasets/datasets/movielens.py +93 -0
  38. recnexteval/datasets/datasets/test.py +46 -0
  39. recnexteval/datasets/datasets/yelp.py +103 -0
  40. recnexteval/datasets/metadata/__init__.py +58 -0
  41. recnexteval/datasets/metadata/amazon.py +68 -0
  42. recnexteval/datasets/metadata/base.py +38 -0
  43. recnexteval/datasets/metadata/lastfm.py +110 -0
  44. recnexteval/datasets/metadata/movielens.py +87 -0
  45. recnexteval/evaluators/__init__.py +189 -0
  46. recnexteval/evaluators/accumulator.py +167 -0
  47. recnexteval/evaluators/base.py +216 -0
  48. recnexteval/evaluators/builder/__init__.py +125 -0
  49. recnexteval/evaluators/builder/base.py +166 -0
  50. recnexteval/evaluators/builder/pipeline.py +111 -0
  51. recnexteval/evaluators/builder/stream.py +54 -0
  52. recnexteval/evaluators/evaluator_pipeline.py +287 -0
  53. recnexteval/evaluators/evaluator_stream.py +374 -0
  54. recnexteval/evaluators/state_management.py +310 -0
  55. recnexteval/evaluators/strategy.py +32 -0
  56. recnexteval/evaluators/util.py +124 -0
  57. recnexteval/matrix/__init__.py +48 -0
  58. recnexteval/matrix/exception.py +5 -0
  59. recnexteval/matrix/interaction_matrix.py +784 -0
  60. recnexteval/matrix/prediction_matrix.py +153 -0
  61. recnexteval/matrix/util.py +24 -0
  62. recnexteval/metrics/__init__.py +57 -0
  63. recnexteval/metrics/binary/__init__.py +4 -0
  64. recnexteval/metrics/binary/hit.py +49 -0
  65. recnexteval/metrics/core/__init__.py +10 -0
  66. recnexteval/metrics/core/base.py +126 -0
  67. recnexteval/metrics/core/elementwise_top_k.py +75 -0
  68. recnexteval/metrics/core/listwise_top_k.py +72 -0
  69. recnexteval/metrics/core/top_k.py +60 -0
  70. recnexteval/metrics/core/util.py +29 -0
  71. recnexteval/metrics/ranking/__init__.py +6 -0
  72. recnexteval/metrics/ranking/dcg.py +55 -0
  73. recnexteval/metrics/ranking/ndcg.py +78 -0
  74. recnexteval/metrics/ranking/precision.py +51 -0
  75. recnexteval/metrics/ranking/recall.py +42 -0
  76. recnexteval/models/__init__.py +4 -0
  77. recnexteval/models/base.py +69 -0
  78. recnexteval/preprocessing/__init__.py +37 -0
  79. recnexteval/preprocessing/filter.py +181 -0
  80. recnexteval/preprocessing/preprocessor.py +137 -0
  81. recnexteval/registries/__init__.py +67 -0
  82. recnexteval/registries/algorithm.py +68 -0
  83. recnexteval/registries/base.py +131 -0
  84. recnexteval/registries/dataset.py +37 -0
  85. recnexteval/registries/metric.py +57 -0
  86. recnexteval/settings/__init__.py +127 -0
  87. recnexteval/settings/base.py +414 -0
  88. recnexteval/settings/exception.py +8 -0
  89. recnexteval/settings/leave_n_out_setting.py +48 -0
  90. recnexteval/settings/processor.py +115 -0
  91. recnexteval/settings/schema.py +11 -0
  92. recnexteval/settings/single_time_point_setting.py +111 -0
  93. recnexteval/settings/sliding_window_setting.py +153 -0
  94. recnexteval/settings/splitters/__init__.py +14 -0
  95. recnexteval/settings/splitters/base.py +57 -0
  96. recnexteval/settings/splitters/n_last.py +39 -0
  97. recnexteval/settings/splitters/n_last_timestamp.py +76 -0
  98. recnexteval/settings/splitters/timestamp.py +82 -0
  99. recnexteval/settings/util.py +0 -0
  100. recnexteval/utils/__init__.py +115 -0
  101. recnexteval/utils/json_to_csv_converter.py +128 -0
  102. recnexteval/utils/logging_tools.py +159 -0
  103. recnexteval/utils/path.py +155 -0
  104. recnexteval/utils/url_certificate_installer.py +54 -0
  105. recnexteval/utils/util.py +166 -0
  106. recnexteval/utils/uuid_util.py +7 -0
  107. recnexteval/utils/yaml_tool.py +65 -0
  108. recnexteval-0.1.0.dist-info/METADATA +85 -0
  109. recnexteval-0.1.0.dist-info/RECORD +110 -0
  110. recnexteval-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,128 @@
1
+ # Reference: https://github.com/Yelp/dataset-examples/blob/master/json_to_csv_converter.py
2
+
3
+ # -*- coding: utf-8 -*-
4
+ """Convert the Yelp Dataset Challenge dataset from json format to csv.
5
+
6
+ For more information on the Yelp Dataset Challenge please visit http://yelp.com/dataset_challenge
7
+
8
+ """
9
+
10
+ import argparse
11
+ import csv
12
+ from collections.abc import MutableMapping
13
+
14
+ import simplejson as json
15
+
16
+
17
def read_and_write_file(json_file_path, csv_file_path, column_names) -> None:
    """Read the line-delimited JSON dataset and write it out as a CSV file.

    Args:
        json_file_path: Path to the input file with one JSON object per line.
        csv_file_path: Path of the CSV file to create (overwritten if present).
        column_names: Iterable of flattened column names used as the header.
    """
    # newline="" is required by the csv module to avoid blank lines on
    # Windows; an explicit encoding keeps input and output consistent.
    with open(csv_file_path, "w", newline="", encoding="utf-8") as fout:
        csv_file = csv.writer(fout)
        csv_file.writerow(list(column_names))
        with open(json_file_path, encoding="utf-8") as fin:
            for line in fin:
                line_contents = json.loads(line)
                csv_file.writerow(get_row(line_contents, column_names))
26
+
27
+
28
def get_superset_of_column_names_from_file(json_file_path):
    """Collect the union of flattened column names over every JSON line in the file."""
    columns = set()
    with open(json_file_path, encoding="utf-8") as fin:
        for raw_line in fin:
            record = json.loads(raw_line)
            columns |= set(get_column_names(record).keys())
    return columns
36
+
37
+
38
def get_column_names(line_contents, parent_key="") -> dict:
    """Return a dict mapping flattened dotted key names to their leaf values.

    Example:

        line_contents = {
            'a': {
                'b': 2,
                'c': 3,
            },
        }

    produces the keys 'a.b' and 'a.c'.

    These become the column names for the eventual csv file.
    """
    flattened = {}
    for key, value in line_contents.items():
        full_key = "{0}.{1}".format(parent_key, key) if parent_key else key
        if isinstance(value, MutableMapping):
            # Nested mapping: recurse, prefixing child keys with this key.
            flattened.update(get_column_names(value, full_key))
        else:
            flattened[full_key] = value
    return flattened
63
+
64
+
65
def get_nested_value(d, key):
    """Look up a value in nested dict ``d`` via a flattened dotted ``key``.

    Example:

        d = {'a': {'b': 2, 'c': 3}}
        key = 'a.b'

    returns 2. Missing keys at any level yield None.
    """
    head, dot, rest = key.partition(".")
    if head not in d:
        return None
    if not dot:
        # Plain key with no remaining dotted segments.
        return d[head]
    return get_nested_value(d[head], rest)
90
+
91
+
92
def get_row(line_contents, column_names):
    """Return a csv compatible row (list of strings) given column names and a dict.

    Args:
        line_contents: Parsed JSON object for one dataset line.
        column_names: Iterable of flattened column names selecting values.
    """
    row = []
    for column_name in column_names:
        line_value = get_nested_value(
            line_contents,
            column_name,
        )
        if line_value is None:
            row.append("")
        else:
            # The upstream Python 2 script encoded str values to UTF-8 here;
            # under Python 3 that formatted bytes as literal "b'...'" text in
            # the CSV. str() the value and let the csv writer handle encoding.
            row.append(str(line_value))
    return row
107
+
108
+
109
if __name__ == "__main__":
    """Convert a yelp dataset file from json to csv."""

    parser = argparse.ArgumentParser(
        description="Convert Yelp Dataset Challenge data from JSON format to CSV.",
    )

    parser.add_argument(
        "json_file",
        type=str,
        help="The json file to convert.",
    )

    args = parser.parse_args()

    json_file = args.json_file
    # Output path: the input path with its ".json" suffix replaced by ".csv".
    csv_file = "{0}.csv".format(json_file.split(".json")[0])

    # Two passes over the input: first collect every column name that occurs,
    # then write the rows against that fixed header.
    column_names = get_superset_of_column_names_from_file(json_file)
    read_and_write_file(json_file, csv_file, column_names)
@@ -0,0 +1,159 @@
1
+ import logging
2
+ import logging.config
3
+ import os
4
+ import warnings
5
+ from collections.abc import Generator
6
+ from contextlib import contextmanager
7
+ from enum import Enum
8
+
9
+ import pyfiglet
10
+ import yaml
11
+
12
+ from recnexteval.utils.path import safe_dir
13
+ from recnexteval.utils.yaml_tool import create_config_yaml
14
+
15
+
16
class LogLevel(Enum):
    """Enumeration mirroring the standard :mod:`logging` severity levels."""

    CRITICAL = logging.CRITICAL
    ERROR = logging.ERROR
    WARNING = logging.WARNING
    INFO = logging.INFO
    DEBUG = logging.DEBUG
    NOTSET = logging.NOTSET

    @classmethod
    def from_string(cls, level: str) -> "LogLevel":
        """Look up a member by its name, ignoring case.

        Args:
            level: Name of the log level (case-insensitive), e.g. ``"info"``.

        Returns:
            The matching :class:`LogLevel` member.
        """
        return cls.__members__[level.upper()]
35
+
36
+
37
def log_level(level: int | str | LogLevel) -> None:
    """Set the level of the root logger and every attached handler.

    Args:
        level: Target level as a :class:`LogLevel` member, a level name
            (str, case-insensitive), or a numeric logging level.

    Returns:
        None
    """
    if isinstance(level, str):
        level = LogLevel.from_string(level)

    if isinstance(level, LogLevel):
        numeric = level.value
    else:
        numeric = level

    root = logging.getLogger()
    root.setLevel(numeric)
    for attached_handler in root.handlers:
        attached_handler.setLevel(numeric)
56
+
57
+
58
log_level_by_name = log_level  # Backwards-compatible alias; accepts the same arguments as log_level.
59
+
60
+
61
def suppress_warnings() -> None:
    """Silence all Python warnings.

    Installs an "ignore" filter covering every warning category and stops
    the logging module from capturing warnings.

    Returns:
        None
    """
    warnings.filterwarnings("ignore")
    logging.captureWarnings(False)
72
+
73
+
74
def enable_warnings() -> None:
    """Restore default warning behavior.

    Clears any filters previously installed and re-enables logging's
    capture of warnings.

    Returns:
        None
    """
    warnings.resetwarnings()
    logging.captureWarnings(True)
84
+
85
+
86
def suppress_specific_warnings(category: type[Warning]) -> None:
    """Ignore all warnings belonging to one category.

    Args:
        category: Warning class to silence (for example, :class:`DeprecationWarning`).

    Returns:
        None
    """
    warnings.filterwarnings(action="ignore", category=category)
96
+
97
+
98
@contextmanager
def warnings_suppressed() -> Generator:
    """Temporarily silence all warnings inside a ``with`` block.

    Yields:
        None. On exit, filters are reset and logging's warning capture is
        re-enabled regardless of whether the block raised.
    """
    warnings.filterwarnings("ignore")
    logging.captureWarnings(False)
    try:
        yield
    finally:
        warnings.resetwarnings()
        logging.captureWarnings(True)
112
+
113
+
114
def prepare_logger(log_config_filename: str) -> dict:
    """Prepare and configure logging from a YAML file.

    This function locates or creates a logging configuration YAML file using
    :func:`recnexteval.utils.yaml_tool.create_config_yaml`, ensures the
    directory for the configured log file exists, writes an ASCII art header
    to the log file, and configures the Python logging system using
    :func:`logging.config.dictConfig`.

    Args:
        log_config_filename: Name of the logging configuration YAML file.

    Returns:
        dict: The parsed logging configuration dictionary.

    Raises:
        FileNotFoundError: If the resolved YAML configuration file cannot be found.
        ValueError: If there is an error parsing the YAML content.
    """
    # create_config_yaml returns a pair; only the resolved path is needed here.
    _, yaml_file_path = create_config_yaml(log_config_filename)
    try:
        with open(yaml_file_path, "r") as stream:
            # NOTE(review): FullLoader trusts the config file's contents;
            # yaml.safe_load would suffice for a plain dictConfig schema.
            config = yaml.load(
                stream, Loader=yaml.FullLoader
            )
    except FileNotFoundError:
        raise FileNotFoundError(f"Configuration file not found at {yaml_file_path}.")
    except yaml.YAMLError as e:
        raise ValueError(f"Error parsing YAML configuration: {e}")

    # Get the log file path from the configuration; assumes the config defines
    # a file handler named "file" with a "filename" key (KeyError otherwise).
    log_file = config["handlers"]["file"]["filename"]

    # Ensure the log file directory exists
    dir_name = os.path.dirname(log_file)
    safe_dir(dir_name)

    # Write ASCII art to the log file; opening with "w" truncates any
    # previous log contents before dictConfig attaches the file handler.
    with open(log_file, "w") as log:
        ascii_art = pyfiglet.figlet_format("recnexteval")
        log.write(ascii_art)
        log.write("\n")

    logging.config.dictConfig(config)
    logging.captureWarnings(True)
    return config
@@ -0,0 +1,155 @@
1
+ """Path utilities for recnexteval library.
2
+
3
+ This module provides functions to resolve paths relative to the repository root,
4
+ ensuring that data and logs are stored consistently regardless of where the
5
+ library is run from.
6
+ """
7
+
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+
13
# Module-level cache so the marker-file search in get_repo_root runs at most once.
_REPO_ROOT_CACHE: Optional[Path] = None
14
+
15
+
16
def get_repo_root(
    marker_files: tuple[str, ...] = (
        ".git",
        "pyproject.toml",
        "setup.py",
        "setup.cfg",
        "README.md",
        "requirements.txt",
    ),
) -> Path:
    """Locate and return the repository root directory.

    Walks upward from this module's directory (up to a fixed depth), looking
    for common repository markers such as `.git` or `pyproject.toml`. The
    first matching directory is cached in `_REPO_ROOT_CACHE` and reused on
    later calls. When no marker is found, the `STREAMSIGHT_ROOT` environment
    variable is consulted as a fallback.

    Args:
        marker_files: Filenames whose presence identifies the repository root.

    Returns:
        Path to the repository root directory.

    Raises:
        RuntimeError: If no marker is found and `STREAMSIGHT_ROOT` is unset
            or points to a non-existent path.
    """
    global _REPO_ROOT_CACHE

    if _REPO_ROOT_CACHE is not None:
        return _REPO_ROOT_CACHE

    # Walk upward from this file's directory, bounded to avoid a long scan.
    candidate = Path(__file__).resolve().parent
    for _ in range(10):
        if any((candidate / marker).exists() for marker in marker_files):
            _REPO_ROOT_CACHE = candidate
            return candidate
        if candidate.parent == candidate:
            # Hit the filesystem root without finding a marker.
            break
        candidate = candidate.parent

    # Fallback: honor an explicitly configured root.
    env_root = os.environ.get("STREAMSIGHT_ROOT")
    if env_root is not None:
        root = Path(env_root)
        if root.exists():
            _REPO_ROOT_CACHE = root
            return root

    raise RuntimeError(
        "Could not find repository root. Please set STREAMSIGHT_ROOT "
        "environment variable or ensure you're running from within the repo."
    )
74
+
75
+
76
def get_data_dir(subdir: str = "") -> Path:
    """Return the `data/` directory inside the repository.

    Args:
        subdir: Optional subdirectory within `data/` to append.

    Returns:
        Path to the data directory, extended by `subdir` when provided.
    """
    base = get_repo_root() / "data"
    return base / subdir if subdir else base
89
+
90
+
91
def get_logs_dir(subdir: str = "") -> Path:
    """Return the `logs/` directory inside the repository.

    Args:
        subdir: Optional subdirectory within `logs/` to append.

    Returns:
        Path to the logs directory, extended by `subdir` when provided.
    """
    base = get_repo_root() / "logs"
    return base / subdir if subdir else base
104
+
105
+
106
def get_cache_dir(subdir: str = "") -> Path:
    """Return the `cache/` directory inside the repository.

    Args:
        subdir: Optional subdirectory within `cache/` to append.

    Returns:
        Path to the cache directory, extended by `subdir` when provided.
    """
    base = get_repo_root() / "cache"
    return base / subdir if subdir else base
119
+
120
+
121
+ def safe_dir(path: Path | str) -> Path:
122
+ """Ensure the given directory exists, creating it if necessary.
123
+
124
+ Args:
125
+ path: Directory path as a :class:`pathlib.Path` or string.
126
+
127
+ Returns:
128
+ The directory path (created if it did not exist).
129
+ """
130
+ path = Path(path)
131
+ path.mkdir(parents=True, exist_ok=True)
132
+ return path
133
+
134
+
135
+ def resolve_path(path: str | Path, relative_to_root: bool = True) -> Path:
136
+ """Resolve a path to an absolute :class:`pathlib.Path`.
137
+
138
+ Args:
139
+ path: Path to resolve, either a string or :class:`pathlib.Path`.
140
+ relative_to_root: If True and `path` is relative, resolve it relative to
141
+ the repository root. If False, resolve it relative to the current
142
+ working directory.
143
+
144
+ Returns:
145
+ The resolved absolute path.
146
+ """
147
+ path = Path(path)
148
+
149
+ if path.is_absolute():
150
+ return path
151
+
152
+ if relative_to_root:
153
+ return get_repo_root() / path
154
+
155
+ return path.resolve()
@@ -0,0 +1,54 @@
1
+ # install_certifi.py
2
+ #
3
+ # sample script to install or update a set of default Root Certificates
4
+ # for the ssl module. Uses the certificates provided by the certifi package:
5
+ # https://pypi.python.org/pypi/certifi
6
+
7
+ # solution is obtained from https://stackoverflow.com/questions/44649449/brew-installation-of-python-3-6-1-ssl-certificate-verify-failed-certificate/49953648#49953648
8
+
9
+ import contextlib
10
+ import os
11
+ import os.path
12
+ import ssl
13
+ import stat
14
+ import subprocess
15
+ import sys
16
+
17
+
18
# Permission mask equivalent to 0o775 (rwxrwxr-x): read/write/execute for
# owner and group, read/execute for others; applied to the CA bundle below.
STAT_0o775 = (
    stat.S_IRUSR
    | stat.S_IWUSR
    | stat.S_IXUSR
    | stat.S_IRGRP
    | stat.S_IWGRP
    | stat.S_IXGRP
    | stat.S_IROTH
    | stat.S_IXOTH
)
28
+
29
+
30
def main():
    """Point the ssl module's default CA file at certifi's certificate bundle.

    Upgrades certifi via pip, then replaces the default OpenSSL CA file with
    a symlink to certifi's bundle. Performs subprocess and filesystem side
    effects; intended to be run as a standalone script.
    """
    # Split the default CA file location into its directory and filename.
    openssl_dir, openssl_cafile = os.path.split(ssl.get_default_verify_paths().openssl_cafile)

    print(" -- pip install --upgrade certifi")
    # -E and -s isolate the pip run from env vars and user site-packages.
    subprocess.check_call(
        [sys.executable, "-E", "-s", "-m", "pip", "install", "--upgrade", "certifi"]
    )

    # Imported after the pip call so the freshly upgraded version is used.
    import certifi

    # change working directory to the default SSL directory
    os.chdir(openssl_dir)
    relpath_to_certifi_cafile = os.path.relpath(certifi.where())
    print(" -- removing any existing file or link")
    with contextlib.suppress(FileNotFoundError):
        os.remove(openssl_cafile)
    print(" -- creating symlink to certifi certificate bundle")
    os.symlink(relpath_to_certifi_cafile, openssl_cafile)
    print(" -- setting permissions")
    os.chmod(openssl_cafile, STAT_0o775)
    print(" -- update complete")


if __name__ == "__main__":
    # Script entry point: mutates the system SSL trust store.
    main()
@@ -0,0 +1,166 @@
1
+ import logging
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import progressbar
6
+ from scipy.sparse import csr_matrix
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def to_tuple(element) -> tuple:
    """Return ``element`` unchanged when it is a tuple, else wrap it in one."""
    return element if isinstance(element, tuple) else (element,)
18
+
19
+
20
+ def arg_to_str(arg: type | str) -> str:
21
+ """Converts a type to its name or returns the string.
22
+
23
+ :param arg: Argument to convert to string.
24
+ :type arg: Union[type, str]
25
+ :return: String representation of the argument.
26
+ :rtype: str
27
+ :raises TypeError: If the argument is not a string or a type.
28
+ """
29
+ if isinstance(arg, type):
30
+ return arg.__name__
31
+ elif not isinstance(arg, str):
32
+ raise TypeError(f"Argument should be string or type, not {type(arg)}!")
33
+ return arg
34
+
35
+
36
def df_to_sparse(
    df,
    item_ix,
    user_ix,
    value_ix=None,
    shape=None,
) -> csr_matrix:
    """Build a users-by-items ``csr_matrix`` from an interaction dataframe.

    :param df: Dataframe with one interaction per row.
    :param item_ix: Name of the item id column.
    :param user_ix: Name of the user id column.
    :param value_ix: Optional name of the value column; ones are used when it
        is absent or missing from ``df``.
    :param shape: Optional (num_users, num_items) shape; inferred from the
        maximum indices when omitted.
    :return: Sparse matrix with users as rows and items as columns.
    :rtype: csr_matrix
    """
    if value_ix is not None and value_ix in df:
        values = df[value_ix]
    else:
        if value_ix is not None:
            # value_ix provided, but not in df
            logger.warning(f"Value column {value_ix} not found in dataframe. Using ones instead.")
        # Scipy sums duplicate (user, item) pairs, so ones become counts.
        values = np.ones(df.shape[0])

    coordinates = list(zip(*df.loc[:, [user_ix, item_ix]].values))
    if coordinates == []:
        coordinates = [[], []]  # Empty zip does not evaluate right

    if shape is None:
        shape = df[user_ix].max() + 1, df[item_ix].max() + 1
    return csr_matrix((values, coordinates), shape=shape, dtype=values.dtype)
65
+
66
+
67
def to_binary(X: csr_matrix) -> csr_matrix:
    """Binarize a sparse matrix: every stored non-zero value becomes 1.

    :param X: Matrix to binarize.
    :type X: csr_matrix
    :return: Matrix of the same dtype with all non-zeros set to 1.
    :rtype: csr_matrix
    """
    # Round-trip through bool zeroes out nothing and maps non-zeros to 1,
    # then the cast back restores the original dtype.
    return X.astype(bool).astype(X.dtype)
78
+
79
+
80
def invert(x: Union[np.ndarray, csr_matrix]) -> Union[np.ndarray, csr_matrix]:
    """Elementwise reciprocal of the non-zero entries; zeros remain zero.

    :param x: Dense array or sparse matrix to invert.
    :type x: Union[np.ndarray, csr_matrix]
    :return: Same-shaped structure with each non-zero ``v`` replaced by ``1/v``.
    :rtype: Union[np.ndarray, csr_matrix]
    :raises TypeError: If ``x`` is neither an ndarray nor a csr_matrix.
    """
    if isinstance(x, np.ndarray):
        out = np.zeros(x.shape)
    elif isinstance(x, csr_matrix):
        out = csr_matrix(x.shape)
    else:
        raise TypeError("Unsupported type for argument x.")
    nz = x.nonzero()
    out[nz] = 1 / x[nz]
    return out
96
+
97
+
98
class ProgressBar:
    """Progress bar as visual.

    Callable whose signature matches a download report hook
    (``block_num, block_size, total_size``) — presumably used as e.g. a
    ``urllib.request.urlretrieve`` reporthook; verify against callers.
    The underlying ``progressbar.ProgressBar`` is created lazily on the
    first call, once the total size is known.
    """

    def __init__(self) -> None:
        # Created lazily on first __call__; None means "not started yet".
        self.pbar = None

    def __call__(self, block_num, block_size, total_size) -> None:
        """Update the bar after a block of `block_size` bytes is transferred."""
        if not self.pbar:
            self.pbar = progressbar.ProgressBar(maxval=total_size)
            self.pbar.start()

        downloaded = block_num * block_size
        if downloaded < total_size:
            self.pbar.update(downloaded)
        else:
            # Download complete (or overshot the estimate): close the bar.
            self.pbar.finish()
114
+
115
+
116
def add_rows_to_csr_matrix(matrix: csr_matrix, n: int = 1) -> csr_matrix:
    """Append ``n`` empty (all-zero) rows to the bottom of a csr_matrix.

    ref: https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr

    :param matrix: Matrix to extend with zero rows.
    :type matrix: csr_matrix
    :param n: Number of zero rows to append, defaults to 1.
    :type n: int
    :return: Matrix with ``n`` extra zero rows.
    :rtype: csr_matrix
    """
    # Empty rows hold no data, so the row pointer simply repeats its final
    # value once per appended row.
    grown_shape = (matrix.shape[0] + n, matrix.shape[1])
    grown_indptr = np.append(matrix.indptr, [matrix.indptr[-1]] * n)
    return csr_matrix(
        (matrix.data, matrix.indices, grown_indptr),
        shape=grown_shape,
        copy=False,
    )
130
+
131
+
132
def add_columns_to_csr_matrix(matrix: csr_matrix, n: int = 1) -> csr_matrix:
    """Append ``n`` empty (all-zero) columns to the right of a csr_matrix.

    https://stackoverflow.com/questions/30691160/effectively-change-dimension-of-scipy-spare-csr-matrix

    :param matrix: Matrix to extend with zero columns.
    :type matrix: csr_matrix
    :param n: Number of zero columns to append, defaults to 1.
    :type n: int
    :return: Matrix with ``n`` extra zero columns.
    :rtype: csr_matrix
    """
    # CSR stores column indices explicitly, so widening the matrix only
    # requires declaring a larger shape over the same buffers.
    widened_shape = (matrix.shape[0], matrix.shape[1] + n)
    return csr_matrix(
        (matrix.data, matrix.indices, matrix.indptr),
        shape=widened_shape,
        copy=False,
    )
149
+
150
+
151
def set_row_csr_addition(A: csr_matrix, row_idx: int, new_row: np.ndarray) -> None:
    """Set row ``row_idx`` of ``A`` to ``new_row`` by adding a one-row delta, in place.

    ref: https://stackoverflow.com/questions/28427236/set-row-of-csr-matrix

    The values are added, not overwritten, so the target row is expected to
    be empty (all zeros) beforehand; any existing entries would be summed
    with ``new_row``.

    :param A: Matrix whose row is set; modified in place.
    :type A: csr_matrix
    :param row_idx: Index of the row to set.
    :type row_idx: int
    :param new_row: Dense row of length ``A.shape[1]``.
    :type new_row: np.ndarray
    """
    # Fixes two defects in the original: (1) indptr must have one entry per
    # row plus one — it was sized by the column count, which only worked for
    # square matrices; (2) `A += ...` rebinds the local name, so the caller
    # never saw the change — the result is copied back into A's buffers.
    indptr = np.zeros(A.shape[0] + 1, dtype=A.indptr.dtype)
    indptr[row_idx + 1 :] = len(new_row)
    indices = np.arange(len(new_row))
    result = A + csr_matrix((new_row, indices, indptr), shape=A.shape)
    A.data, A.indices, A.indptr = result.data, result.indices, result.indptr
@@ -0,0 +1,7 @@
1
+ import datetime
2
+ from uuid import NAMESPACE_DNS, UUID, uuid5
3
+
4
+
5
+ def generate_algorithm_uuid(name: str) -> UUID:
6
+ """Generate a UUID for an algorithm based on its name and current timestamp."""
7
+ return uuid5(NAMESPACE_DNS, f"{name}{datetime.datetime.now()}")