recnexteval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recnexteval/__init__.py +20 -0
- recnexteval/algorithms/__init__.py +99 -0
- recnexteval/algorithms/base.py +377 -0
- recnexteval/algorithms/baseline/__init__.py +10 -0
- recnexteval/algorithms/baseline/decay_popularity.py +110 -0
- recnexteval/algorithms/baseline/most_popular.py +72 -0
- recnexteval/algorithms/baseline/random.py +39 -0
- recnexteval/algorithms/baseline/recent_popularity.py +34 -0
- recnexteval/algorithms/itemknn/__init__.py +14 -0
- recnexteval/algorithms/itemknn/itemknn.py +119 -0
- recnexteval/algorithms/itemknn/itemknn_incremental.py +65 -0
- recnexteval/algorithms/itemknn/itemknn_incremental_movielens.py +95 -0
- recnexteval/algorithms/itemknn/itemknn_rolling.py +17 -0
- recnexteval/algorithms/itemknn/itemknn_static.py +31 -0
- recnexteval/algorithms/time_aware_item_knn/__init__.py +11 -0
- recnexteval/algorithms/time_aware_item_knn/base.py +248 -0
- recnexteval/algorithms/time_aware_item_knn/decay_functions.py +260 -0
- recnexteval/algorithms/time_aware_item_knn/ding_2005.py +52 -0
- recnexteval/algorithms/time_aware_item_knn/liu_2010.py +65 -0
- recnexteval/algorithms/time_aware_item_knn/similarity_functions.py +106 -0
- recnexteval/algorithms/time_aware_item_knn/top_k.py +61 -0
- recnexteval/algorithms/time_aware_item_knn/utils.py +47 -0
- recnexteval/algorithms/time_aware_item_knn/vaz_2013.py +50 -0
- recnexteval/algorithms/utils.py +51 -0
- recnexteval/datasets/__init__.py +109 -0
- recnexteval/datasets/base.py +316 -0
- recnexteval/datasets/config/__init__.py +113 -0
- recnexteval/datasets/config/amazon.py +188 -0
- recnexteval/datasets/config/base.py +72 -0
- recnexteval/datasets/config/lastfm.py +105 -0
- recnexteval/datasets/config/movielens.py +169 -0
- recnexteval/datasets/config/yelp.py +25 -0
- recnexteval/datasets/datasets/__init__.py +24 -0
- recnexteval/datasets/datasets/amazon.py +151 -0
- recnexteval/datasets/datasets/base.py +250 -0
- recnexteval/datasets/datasets/lastfm.py +121 -0
- recnexteval/datasets/datasets/movielens.py +93 -0
- recnexteval/datasets/datasets/test.py +46 -0
- recnexteval/datasets/datasets/yelp.py +103 -0
- recnexteval/datasets/metadata/__init__.py +58 -0
- recnexteval/datasets/metadata/amazon.py +68 -0
- recnexteval/datasets/metadata/base.py +38 -0
- recnexteval/datasets/metadata/lastfm.py +110 -0
- recnexteval/datasets/metadata/movielens.py +87 -0
- recnexteval/evaluators/__init__.py +189 -0
- recnexteval/evaluators/accumulator.py +167 -0
- recnexteval/evaluators/base.py +216 -0
- recnexteval/evaluators/builder/__init__.py +125 -0
- recnexteval/evaluators/builder/base.py +166 -0
- recnexteval/evaluators/builder/pipeline.py +111 -0
- recnexteval/evaluators/builder/stream.py +54 -0
- recnexteval/evaluators/evaluator_pipeline.py +287 -0
- recnexteval/evaluators/evaluator_stream.py +374 -0
- recnexteval/evaluators/state_management.py +310 -0
- recnexteval/evaluators/strategy.py +32 -0
- recnexteval/evaluators/util.py +124 -0
- recnexteval/matrix/__init__.py +48 -0
- recnexteval/matrix/exception.py +5 -0
- recnexteval/matrix/interaction_matrix.py +784 -0
- recnexteval/matrix/prediction_matrix.py +153 -0
- recnexteval/matrix/util.py +24 -0
- recnexteval/metrics/__init__.py +57 -0
- recnexteval/metrics/binary/__init__.py +4 -0
- recnexteval/metrics/binary/hit.py +49 -0
- recnexteval/metrics/core/__init__.py +10 -0
- recnexteval/metrics/core/base.py +126 -0
- recnexteval/metrics/core/elementwise_top_k.py +75 -0
- recnexteval/metrics/core/listwise_top_k.py +72 -0
- recnexteval/metrics/core/top_k.py +60 -0
- recnexteval/metrics/core/util.py +29 -0
- recnexteval/metrics/ranking/__init__.py +6 -0
- recnexteval/metrics/ranking/dcg.py +55 -0
- recnexteval/metrics/ranking/ndcg.py +78 -0
- recnexteval/metrics/ranking/precision.py +51 -0
- recnexteval/metrics/ranking/recall.py +42 -0
- recnexteval/models/__init__.py +4 -0
- recnexteval/models/base.py +69 -0
- recnexteval/preprocessing/__init__.py +37 -0
- recnexteval/preprocessing/filter.py +181 -0
- recnexteval/preprocessing/preprocessor.py +137 -0
- recnexteval/registries/__init__.py +67 -0
- recnexteval/registries/algorithm.py +68 -0
- recnexteval/registries/base.py +131 -0
- recnexteval/registries/dataset.py +37 -0
- recnexteval/registries/metric.py +57 -0
- recnexteval/settings/__init__.py +127 -0
- recnexteval/settings/base.py +414 -0
- recnexteval/settings/exception.py +8 -0
- recnexteval/settings/leave_n_out_setting.py +48 -0
- recnexteval/settings/processor.py +115 -0
- recnexteval/settings/schema.py +11 -0
- recnexteval/settings/single_time_point_setting.py +111 -0
- recnexteval/settings/sliding_window_setting.py +153 -0
- recnexteval/settings/splitters/__init__.py +14 -0
- recnexteval/settings/splitters/base.py +57 -0
- recnexteval/settings/splitters/n_last.py +39 -0
- recnexteval/settings/splitters/n_last_timestamp.py +76 -0
- recnexteval/settings/splitters/timestamp.py +82 -0
- recnexteval/settings/util.py +0 -0
- recnexteval/utils/__init__.py +115 -0
- recnexteval/utils/json_to_csv_converter.py +128 -0
- recnexteval/utils/logging_tools.py +159 -0
- recnexteval/utils/path.py +155 -0
- recnexteval/utils/url_certificate_installer.py +54 -0
- recnexteval/utils/util.py +166 -0
- recnexteval/utils/uuid_util.py +7 -0
- recnexteval/utils/yaml_tool.py +65 -0
- recnexteval-0.1.0.dist-info/METADATA +85 -0
- recnexteval-0.1.0.dist-info/RECORD +110 -0
- recnexteval-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# Reference: https://github.com/Yelp/dataset-examples/blob/master/json_to_csv_converter.py
|
|
2
|
+
|
|
3
|
+
# -*- coding: utf-8 -*-
|
|
4
|
+
"""Convert the Yelp Dataset Challenge dataset from json format to csv.
|
|
5
|
+
|
|
6
|
+
For more information on the Yelp Dataset Challenge please visit http://yelp.com/dataset_challenge
|
|
7
|
+
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import argparse
|
|
11
|
+
import csv
|
|
12
|
+
from collections.abc import MutableMapping
|
|
13
|
+
|
|
14
|
+
import simplejson as json
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def read_and_write_file(json_file_path, csv_file_path, column_names) -> None:
    """Read in the json dataset file and write it out to a csv file, given the column names.

    :param json_file_path: Path of the line-delimited json input file.
    :param csv_file_path: Path of the csv file to create (overwritten if present).
    :param column_names: Iterable of flattened column names; written as the header row.
    """
    # newline="" is required by the csv module so it controls line endings
    # itself (otherwise rows gain blank lines on \r\n platforms); utf-8
    # matches the encoding used to read the json input.
    with open(csv_file_path, "w", newline="", encoding="utf-8") as fout:
        csv_file = csv.writer(fout)
        csv_file.writerow(list(column_names))
        with open(json_file_path, encoding="utf-8") as fin:
            for line in fin:
                line_contents = json.loads(line)
                csv_file.writerow(get_row(line_contents, column_names))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_superset_of_column_names_from_file(json_file_path):
    """Read in the json dataset file and return the superset of column names."""
    superset: set = set()
    with open(json_file_path, encoding="utf-8") as fin:
        for record in fin:
            # Union in the flattened key names of every record.
            superset |= set(get_column_names(json.loads(record)).keys())
    return superset
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_column_names(line_contents, parent_key="") -> dict:
    """Return a dict mapping flattened (dotted) key names to their values.

    Example:

        line_contents = {
            'a': {
                'b': 2,
                'c': 3,
            },
        }

        will return: {'a.b': 2, 'a.c': 3}

    The keys of the returned dict are the column names for the eventual csv
    file. (Note: despite the historical phrasing, a dict — not a list — has
    always been returned.)

    :param line_contents: (Possibly nested) mapping parsed from one json record.
    :param parent_key: Dotted key prefix accumulated while recursing.
    """
    column_names = []
    for k, v in line_contents.items():
        column_name = "{0}.{1}".format(parent_key, k) if parent_key else k
        if isinstance(v, MutableMapping):
            # Flatten nested mappings by recursing with the extended prefix.
            column_names.extend(get_column_names(v, column_name).items())
        else:
            column_names.append((column_name, v))
    return dict(column_names)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_nested_value(d, key):
    """Look up a flattened key from `get_column_names` in a nested dict `d`.

    Example:

        d = {
            'a': {
                'b': 2,
                'c': 3,
            },
        }
        key = 'a.b'

        will return: 2

    A key that is missing at any level yields None.
    """
    head, dot, rest = key.partition(".")
    if head not in d:
        return None
    if not dot:
        # No dot left: `head` is the full key.
        return d[head]
    return get_nested_value(d[head], rest)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_row(line_contents, column_names):
    """Return a csv compatible row given column names and a dict.

    :param line_contents: Parsed json record (possibly nested dict).
    :param column_names: Iterable of flattened column names; their order
        defines the cell order in the returned row.
    :return: List of stringified cell values; missing values become "".
    """
    row = []
    for column_name in column_names:
        line_value = get_nested_value(
            line_contents,
            column_name,
        )
        if line_value is None:
            row.append("")
        else:
            # The original Python-2-era code called .encode("utf-8") on str
            # values, which under Python 3 wrote literal "b'...'" text into
            # the csv; write the string representation directly instead.
            row.append("{0}".format(line_value))
    return row
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
if __name__ == "__main__":
    """Convert a yelp dataset file from json to csv."""

    parser = argparse.ArgumentParser(
        description="Convert Yelp Dataset Challenge data from JSON format to CSV.",
    )
    parser.add_argument(
        "json_file",
        type=str,
        help="The json file to convert.",
    )
    cli_args = parser.parse_args()

    source_path = cli_args.json_file
    # Derive the output name by swapping the .json suffix for .csv.
    target_path = "{0}.csv".format(source_path.split(".json")[0])

    header_columns = get_superset_of_column_names_from_file(source_path)
    read_and_write_file(source_path, target_path, header_columns)
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import logging.config
|
|
3
|
+
import os
|
|
4
|
+
import warnings
|
|
5
|
+
from collections.abc import Generator
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from enum import Enum
|
|
8
|
+
|
|
9
|
+
import pyfiglet
|
|
10
|
+
import yaml
|
|
11
|
+
|
|
12
|
+
from recnexteval.utils.path import safe_dir
|
|
13
|
+
from recnexteval.utils.yaml_tool import create_config_yaml
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LogLevel(Enum):
    """Enumeration of the standard :mod:`logging` severity levels."""

    CRITICAL = logging.CRITICAL
    ERROR = logging.ERROR
    WARNING = logging.WARNING
    INFO = logging.INFO
    DEBUG = logging.DEBUG
    NOTSET = logging.NOTSET

    @classmethod
    def from_string(cls, level: str) -> "LogLevel":
        """Look up a member by its case-insensitive name.

        Args:
            level: Name of the log level (case-insensitive).

        Returns:
            The matching :class:`LogLevel` member.
        """
        # Enum name lookup is exact, so normalize to upper case first.
        return cls[level.upper()]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def log_level(level: int | str | LogLevel) -> None:
    """Change the logging level for the root logger.

    Args:
        level: The logging level to set. May be a :class:`LogLevel` enum,
            a level name (str, case-insensitive), or an integer logging level.

    Returns:
        None
    """
    if isinstance(level, str):
        level = LogLevel.from_string(level)
    numeric = level.value if isinstance(level, LogLevel) else level

    root = logging.getLogger()
    root.setLevel(numeric)
    # Handlers filter independently of the logger, so align them too.
    for handler in root.handlers:
        handler.setLevel(numeric)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Backward-compatible alias; accepts the same int | str | LogLevel argument.
log_level_by_name = log_level  # Alias for convenience
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def suppress_warnings() -> None:
    """Silence all Python warnings.

    Installs an "ignore" filter for every warning category and stops the
    logging module from capturing warnings.

    Returns:
        None
    """
    warnings.filterwarnings("ignore")
    logging.captureWarnings(False)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def enable_warnings() -> None:
    """Enable Python warnings (reset to default behavior).

    Clears any warning filters previously installed and re-enables
    logging's capture of warnings.

    Returns:
        None
    """
    warnings.resetwarnings()
    logging.captureWarnings(True)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def suppress_specific_warnings(category: type[Warning]) -> None:
    """Silence only warnings of one specific category.

    Args:
        category: Warning class to ignore (for example, :class:`DeprecationWarning`).

    Returns:
        None
    """
    warnings.filterwarnings("ignore", category=category)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@contextmanager
def warnings_suppressed() -> Generator:
    """Context manager that temporarily silences all warnings.

    On exit, warning capture is re-enabled and all filters are reset,
    regardless of whether the body raised.

    Yields:
        None: Warnings are suppressed inside the context block.
    """
    logging.captureWarnings(False)
    warnings.filterwarnings("ignore")
    try:
        yield
    finally:
        # Restore capture and clear the blanket "ignore" filter.
        logging.captureWarnings(True)
        warnings.resetwarnings()
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def prepare_logger(log_config_filename: str) -> dict:
    """Prepare and configure logging from a YAML file.

    This function locates or creates a logging configuration YAML file using
    :func:`recnexteval.utils.yaml_tool.create_config_yaml`, ensures the
    directory for the configured log file exists, writes an ASCII art header
    to the log file, and configures the Python logging system using
    :func:`logging.config.dictConfig`.

    Args:
        log_config_filename: Name of the logging configuration YAML file.

    Returns:
        dict: The parsed logging configuration dictionary.

    Raises:
        FileNotFoundError: If the resolved YAML configuration file cannot be found.
        ValueError: If there is an error parsing the YAML content.
    """
    _, yaml_file_path = create_config_yaml(log_config_filename)
    try:
        with open(yaml_file_path, "r") as stream:
            # safe_load: a logging config is plain data, so there is no need
            # to allow arbitrary YAML tags.
            config = yaml.safe_load(stream)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Configuration file not found at {yaml_file_path}.") from e
    except yaml.YAMLError as e:
        raise ValueError(f"Error parsing YAML configuration: {e}") from e

    # Get the log file path from the configuration
    log_file = config["handlers"]["file"]["filename"]

    # Ensure the log file directory exists
    dir_name = os.path.dirname(log_file)
    safe_dir(dir_name)

    # Write ASCII art to the log file. Mode "w" truncates, so every run
    # starts a fresh log topped with the banner.
    with open(log_file, "w") as log:
        ascii_art = pyfiglet.figlet_format("recnexteval")
        log.write(ascii_art)
        log.write("\n")

    logging.config.dictConfig(config)
    logging.captureWarnings(True)
    return config
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""Path utilities for recnexteval library.
|
|
2
|
+
|
|
3
|
+
This module provides functions to resolve paths relative to the repository root,
|
|
4
|
+
ensuring that data and logs are stored consistently regardless of where the
|
|
5
|
+
library is run from.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Process-wide cache for get_repo_root(); populated on the first successful lookup.
_REPO_ROOT_CACHE: Optional[Path] = None
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_repo_root(
    marker_files: tuple[str, ...] = (
        ".git",
        "pyproject.toml",
        "setup.py",
        "setup.cfg",
        "README.md",
        "requirements.txt",
    ),
) -> Path:
    """Find and return the repository root directory.

    Walks upward from this file's directory (at most 10 levels) until a
    directory containing one of the marker files is found, falling back to
    the `STREAMSIGHT_ROOT` environment variable. The hit is cached so the
    filesystem is only traversed once per process.

    Args:
        marker_files: Tuple of filenames that indicate the repository root.

    Returns:
        Path to the repository root directory.

    Raises:
        RuntimeError: If the repository root cannot be located and the
            `STREAMSIGHT_ROOT` environment variable is not set or invalid.
    """
    global _REPO_ROOT_CACHE

    if _REPO_ROOT_CACHE is not None:
        return _REPO_ROOT_CACHE

    start = Path(__file__).resolve().parent
    # `parents` already ends at the filesystem root, so capping the candidate
    # list at 10 entries reproduces a bounded upward walk.
    for candidate in [start, *start.parents][:10]:
        if any((candidate / marker).exists() for marker in marker_files):
            _REPO_ROOT_CACHE = candidate
            return candidate

    # Fallback: explicit override via environment variable.
    env_root = os.environ.get("STREAMSIGHT_ROOT")
    if env_root is not None:
        root = Path(env_root)
        if root.exists():
            _REPO_ROOT_CACHE = root
            return root

    raise RuntimeError(
        "Could not find repository root. Please set STREAMSIGHT_ROOT "
        "environment variable or ensure you're running from within the repo."
    )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_data_dir(subdir: str = "") -> Path:
    """Return the `data/` directory inside the repository.

    Args:
        subdir: Optional subdirectory within `data/` to append.

    Returns:
        Path to the data directory (with `subdir` appended when provided).
    """
    base = get_repo_root() / "data"
    return base / subdir if subdir else base
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_logs_dir(subdir: str = "") -> Path:
    """Return the `logs/` directory inside the repository.

    Args:
        subdir: Optional subdirectory within `logs/` to append.

    Returns:
        Path to the logs directory (with `subdir` appended when provided).
    """
    base = get_repo_root() / "logs"
    return base / subdir if subdir else base
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def get_cache_dir(subdir: str = "") -> Path:
    """Return the `cache/` directory inside the repository.

    Args:
        subdir: Optional subdirectory within `cache/` to append.

    Returns:
        Path to the cache directory (with `subdir` appended when provided).
    """
    base = get_repo_root() / "cache"
    return base / subdir if subdir else base
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def safe_dir(path: Path | str) -> Path:
|
|
122
|
+
"""Ensure the given directory exists, creating it if necessary.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
path: Directory path as a :class:`pathlib.Path` or string.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
The directory path (created if it did not exist).
|
|
129
|
+
"""
|
|
130
|
+
path = Path(path)
|
|
131
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
132
|
+
return path
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def resolve_path(path: str | Path, relative_to_root: bool = True) -> Path:
|
|
136
|
+
"""Resolve a path to an absolute :class:`pathlib.Path`.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
path: Path to resolve, either a string or :class:`pathlib.Path`.
|
|
140
|
+
relative_to_root: If True and `path` is relative, resolve it relative to
|
|
141
|
+
the repository root. If False, resolve it relative to the current
|
|
142
|
+
working directory.
|
|
143
|
+
|
|
144
|
+
Returns:
|
|
145
|
+
The resolved absolute path.
|
|
146
|
+
"""
|
|
147
|
+
path = Path(path)
|
|
148
|
+
|
|
149
|
+
if path.is_absolute():
|
|
150
|
+
return path
|
|
151
|
+
|
|
152
|
+
if relative_to_root:
|
|
153
|
+
return get_repo_root() / path
|
|
154
|
+
|
|
155
|
+
return path.resolve()
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# install_certifi.py
|
|
2
|
+
#
|
|
3
|
+
# sample script to install or update a set of default Root Certificates
|
|
4
|
+
# for the ssl module. Uses the certificates provided by the certifi package:
|
|
5
|
+
# https://pypi.python.org/pypi/certifi
|
|
6
|
+
|
|
7
|
+
# solution is obtained from https://stackoverflow.com/questions/44649449/brew-installation-of-python-3-6-1-ssl-certificate-verify-failed-certificate/49953648#49953648
|
|
8
|
+
|
|
9
|
+
import contextlib
|
|
10
|
+
import os
|
|
11
|
+
import os.path
|
|
12
|
+
import ssl
|
|
13
|
+
import stat
|
|
14
|
+
import subprocess
|
|
15
|
+
import sys
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Mode 0o775: owner rwx, group rwx, other r-x — expressed with the
# combined stat constants.
STAT_0o775 = stat.S_IRWXU | stat.S_IRWXG | stat.S_IROTH | stat.S_IXOTH
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def main():
    """Install or refresh the default root-CA bundle used by :mod:`ssl`.

    Upgrades the certifi package via pip, then replaces the interpreter's
    default cafile with a relative symlink to certifi's bundle and relaxes
    its permissions.
    """
    cafile_dir, cafile_name = os.path.split(ssl.get_default_verify_paths().openssl_cafile)

    print(" -- pip install --upgrade certifi")
    subprocess.check_call(
        [sys.executable, "-E", "-s", "-m", "pip", "install", "--upgrade", "certifi"]
    )

    # Imported after the pip upgrade so the freshest bundle path is used.
    import certifi

    # change working directory to the default SSL directory
    os.chdir(cafile_dir)
    bundle_relpath = os.path.relpath(certifi.where())
    print(" -- removing any existing file or link")
    with contextlib.suppress(FileNotFoundError):
        os.remove(cafile_name)
    print(" -- creating symlink to certifi certificate bundle")
    os.symlink(bundle_relpath, cafile_name)
    print(" -- setting permissions")
    os.chmod(cafile_name, STAT_0o775)
    print(" -- update complete")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Script entry point: perform the certifi install/update when run directly.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import progressbar
|
|
6
|
+
from scipy.sparse import csr_matrix
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def to_tuple(element) -> tuple:
    """Whether single element or tuple, always returns as tuple."""
    return element if isinstance(element, tuple) else (element,)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def arg_to_str(arg: type | str) -> str:
|
|
21
|
+
"""Converts a type to its name or returns the string.
|
|
22
|
+
|
|
23
|
+
:param arg: Argument to convert to string.
|
|
24
|
+
:type arg: Union[type, str]
|
|
25
|
+
:return: String representation of the argument.
|
|
26
|
+
:rtype: str
|
|
27
|
+
:raises TypeError: If the argument is not a string or a type.
|
|
28
|
+
"""
|
|
29
|
+
if isinstance(arg, type):
|
|
30
|
+
return arg.__name__
|
|
31
|
+
elif not isinstance(arg, str):
|
|
32
|
+
raise TypeError(f"Argument should be string or type, not {type(arg)}!")
|
|
33
|
+
return arg
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def df_to_sparse(
    df,
    item_ix,
    user_ix,
    value_ix=None,
    shape=None,
) -> csr_matrix:
    """Convert an interaction DataFrame to a user x item ``csr_matrix``.

    :param df: DataFrame with one interaction per row.
    :param item_ix: Name of the column holding item indices.
    :type item_ix: str
    :param user_ix: Name of the column holding user indices.
    :type user_ix: str
    :param value_ix: Optional name of the column holding cell values. When
        None, or when the column is absent from ``df``, every interaction
        counts as 1.
    :param shape: Optional ``(num_users, num_items)`` shape. Defaults to the
        maximum index in each column plus one.
    :return: Sparse matrix with users as rows and items as columns. Scipy
        sums duplicate (user, item) pairs, so repeated interactions
        accumulate into counts.
    :rtype: csr_matrix
    """
    if value_ix is not None and value_ix in df:
        values = df[value_ix]
    else:
        if value_ix is not None:
            # value_ix provided, but not in df
            logger.warning(f"Value column {value_ix} not found in dataframe. Using ones instead.")

        num_entries = df.shape[0]
        # Scipy sums up the entries when an index-pair occurs more than once,
        # resulting in the actual counts being stored. Neat!
        values = np.ones(num_entries)

    indices = list(zip(*df.loc[:, [user_ix, item_ix]].values))

    if indices == []:
        indices = [[], []]  # Empty zip does not evaluate right

    if shape is None:
        shape = df[user_ix].max() + 1, df[item_ix].max() + 1
    sparse_matrix = csr_matrix((values, indices), shape=shape, dtype=values.dtype)

    return sparse_matrix
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def to_binary(X: csr_matrix) -> csr_matrix:
    """Binarize a matrix: every stored non-zero value becomes 1.

    :param X: Matrix to convert to binary.
    :type X: csr_matrix
    :return: Binary matrix with the original dtype preserved.
    :rtype: csr_matrix
    """
    # bool cast collapses values to True, the second cast restores the dtype.
    return X.astype(bool).astype(X.dtype)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def invert(x: Union[np.ndarray, csr_matrix]) -> Union[np.ndarray, csr_matrix]:
    """Element-wise invert the non-zero entries of an array or sparse matrix.

    Zero entries stay zero — they are never divided, which avoids
    division-by-zero warnings.

    :param x: Dense array or CSR matrix to invert.
    :type x: Union[np.ndarray, csr_matrix]
    :return: New object of the same kind with each non-zero value ``v``
        replaced by ``1 / v``.
    :rtype: Union[np.ndarray, csr_matrix]
    :raises TypeError: If ``x`` is neither ``np.ndarray`` nor ``csr_matrix``.
    """
    if isinstance(x, np.ndarray):
        ret = np.zeros(x.shape)
    elif isinstance(x, csr_matrix):
        ret = csr_matrix(x.shape)
    else:
        raise TypeError("Unsupported type for argument x.")
    # Only touch the non-zero positions; zeros remain zero in `ret`.
    nonzero = x.nonzero()
    ret[nonzero] = 1 / x[nonzero]
    return ret
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ProgressBar:
    """Callable progress indicator for block-wise downloads.

    The call signature (block number, block size, total size) matches a
    ``urllib.request.urlretrieve`` reporthook — presumably used as such;
    verify against the caller.
    """

    def __init__(self) -> None:
        # Created lazily on the first call, once the total size is known.
        self.pbar = None

    def __call__(self, block_num, block_size, total_size) -> None:
        if not self.pbar:
            self.pbar = progressbar.ProgressBar(maxval=total_size)
            self.pbar.start()

        done = block_num * block_size
        if done < total_size:
            self.pbar.update(done)
        else:
            self.pbar.finish()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def add_rows_to_csr_matrix(matrix: csr_matrix, n: int = 1) -> csr_matrix:
    """Append ``n`` all-zero rows to the bottom of a csr_matrix.

    ref: https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr

    :param matrix: Matrix to add a row of zeros to.
    :type matrix: csr_matrix
    :return: Matrix with a row of zeros added.
    :rtype: csr_matrix
    """
    # Empty rows hold no data: repeating the final indptr entry once per
    # appended row is all that is needed.
    grown_indptr = np.append(matrix.indptr, [matrix.indptr[-1]] * n)
    return csr_matrix(
        (matrix.data, matrix.indices, grown_indptr),
        shape=(matrix.shape[0] + n, matrix.shape[1]),
        copy=False,
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def add_columns_to_csr_matrix(matrix: csr_matrix, n: int = 1) -> csr_matrix:
    """Append ``n`` all-zero columns to the right of a csr_matrix.

    https://stackoverflow.com/questions/30691160/effectively-change-dimension-of-scipy-spare-csr-matrix

    :param matrix: Matrix to add a column of zeros to.
    :type matrix: csr_matrix
    :return: Matrix with a column of zeros added.
    :rtype: csr_matrix
    """
    # For CSR, the column count is pure metadata: reuse the existing arrays
    # and simply widen the declared shape.
    return csr_matrix(
        (matrix.data, matrix.indices, matrix.indptr),
        shape=(matrix.shape[0], matrix.shape[1] + n),
        copy=False,
    )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def set_row_csr_addition(A: csr_matrix, row_idx: int, new_row: np.ndarray) -> None:
    """Add ``new_row`` onto row ``row_idx`` of ``A``, in place.

    ref: https://stackoverflow.com/questions/28427236/set-row-of-csr-matrix

    Note: the result is the SUM of the existing row and ``new_row``; it only
    truly "sets" the row when the existing row is all zeros.

    :param A: Matrix whose row is updated in place.
    :type A: csr_matrix
    :param row_idx: Index of the row to update.
    :type row_idx: int
    :param new_row: Dense row of length ``A.shape[1]`` to add.
    :type new_row: np.ndarray
    """
    # Build a sparse matrix that is zero everywhere except row_idx, which
    # densely stores new_row. indptr needs one entry per ROW plus one
    # (the previous code sized it by the column count, which broke on
    # non-square matrices); rows after row_idx all point past the
    # shape[1] stored values.
    indptr = np.zeros(A.shape[0] + 1, dtype=np.int64)
    indptr[row_idx + 1 :] = A.shape[1]
    indices = np.arange(A.shape[1])
    addend = csr_matrix((new_row, indices, indptr), shape=A.shape)
    # `A += addend` only rebinds the local name (scipy sparse addition
    # returns a new matrix), leaving the caller's object untouched; write
    # the summed buffers back into A so the update is genuinely in place.
    result = A + addend
    A.data, A.indices, A.indptr = result.data, result.indices, result.indptr
|