imbed_data_prep 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imbed_data_prep/__init__.py +12 -0
- imbed_data_prep/arxiv/README.md +31 -0
- imbed_data_prep/arxiv/__init__.py +10 -0
- imbed_data_prep/embeddings_of_aggregations/README.md +46 -0
- imbed_data_prep/embeddings_of_aggregations/__init__.py +333 -0
- imbed_data_prep/embeddings_of_aggregations/embeddings_and_order.ipynb +5573 -0
- imbed_data_prep/epstein_files/README.md +182 -0
- imbed_data_prep/epstein_files/__init__.py +1061 -0
- imbed_data_prep/epstein_files/epstein_files.ipynb +2071 -0
- imbed_data_prep/epstein_files/epstein_files_tables_info.json +1 -0
- imbed_data_prep/epstein_files/epstein_files_tables_info.pickle +0 -0
- imbed_data_prep/eurovis/README.md +52 -0
- imbed_data_prep/eurovis/__init__.py +146 -0
- imbed_data_prep/eurovis/eurovis.ipynb +3345 -0
- imbed_data_prep/github_repos/README.md +45 -0
- imbed_data_prep/github_repos/__init__.py +190 -0
- imbed_data_prep/github_repos/github_repos.ipynb +840 -0
- imbed_data_prep/hcp/README.md +48 -0
- imbed_data_prep/hcp/__init__.py +253 -0
- imbed_data_prep/hcp/hcp_analysis.ipynb +3886 -0
- imbed_data_prep/jersey_laws/README.md +45 -0
- imbed_data_prep/jersey_laws/__init__.py +85 -0
- imbed_data_prep/jersey_laws/jersey_laws.ipynb +509 -0
- imbed_data_prep/lmsys_ai_conversations/README.md +57 -0
- imbed_data_prep/lmsys_ai_conversations/__init__.py +786 -0
- imbed_data_prep/mcdonalds_reviews/README.md +51 -0
- imbed_data_prep/mcdonalds_reviews/__init__.py +463 -0
- imbed_data_prep/mcdonalds_reviews/mcdonalds_reviews_dacc.ipynb +1240 -0
- imbed_data_prep/prompt_injections/README.md +44 -0
- imbed_data_prep/prompt_injections/__init__.py +63 -0
- imbed_data_prep/prompt_injections/prompt_injection_w_umap_embeddings.tsv +691 -0
- imbed_data_prep/trump_vs_zelenskyy/README.md +60 -0
- imbed_data_prep/trump_vs_zelenskyy/__init__.py +569 -0
- imbed_data_prep/trump_vs_zelenskyy/trump_vs_zelensky.md +448 -0
- imbed_data_prep/trump_vs_zelenskyy/trump_vs_zelenskyy.ipynb +3363 -0
- imbed_data_prep/trump_vs_zelenskyy/trump_vs_zelenskyy_embeddings.parquet +0 -0
- imbed_data_prep/trump_vs_zelenskyy/trump_vs_zelenskyy_transcript.parquet +0 -0
- imbed_data_prep/twitter_sentiment/README.md +47 -0
- imbed_data_prep/twitter_sentiment/__init__.py +174 -0
- imbed_data_prep/twitter_sentiment/twitter_sentiment.ipynb +616 -0
- imbed_data_prep/ultra_chat/README.md +37 -0
- imbed_data_prep/ultra_chat/__init__.py +10 -0
- imbed_data_prep/ultra_chat/ultra_chat.ipynb +229 -0
- imbed_data_prep/wildchat/README.md +54 -0
- imbed_data_prep/wildchat/__init__.py +265 -0
- imbed_data_prep/wildchat/wildchat.ipynb +7787 -0
- imbed_data_prep/wordnet_words/README.md +77 -0
- imbed_data_prep/wordnet_words/__init__.py +1212 -0
- imbed_data_prep/wordnet_words/test_synset_refactor.ipynb +229 -0
- imbed_data_prep/wordnet_words/wordnet_words.ipynb +4341 -0
- imbed_data_prep-0.1.1.dist-info/METADATA +41 -0
- imbed_data_prep-0.1.1.dist-info/RECORD +54 -0
- imbed_data_prep-0.1.1.dist-info/WHEEL +4 -0
- imbed_data_prep-0.1.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Modules to acquire and prepare data for the imbed package."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
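# # Use __getitem__ to protect access of the modules of list_of_modules
# # NOTE: Python never calls a module-level `__getitem__`; a module hook for
# # attribute access must be named `__getattr__` (PEP 562).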
|
|
5
|
+
|
|
6
|
+
# def __getitem__(name):
|
|
7
|
+
# # if name in list_of_modules:
|
|
8
|
+
# try:
|
|
9
|
+
# return globals()[name]
|
|
10
|
+
# except KeyError:
|
|
11
|
+
# pass # will raise ImportError below
|
|
12
|
+
# raise ImportError(f"No module named {name}")
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# arXiv
|
|
2
|
+
|
|
3
|
+
Data preparation for arXiv papers.
|
|
4
|
+
|
|
5
|
+
## Status
|
|
6
|
+
|
|
7
|
+
This module has been migrated to the standalone
|
|
8
|
+
[`xv`](https://pypi.org/project/xv/) package on PyPI.
|
|
9
|
+
The module here is a thin wrapper that re-exports from `xv`.
|
|
10
|
+
|
|
11
|
+
## Data source
|
|
12
|
+
|
|
13
|
+
[arXiv](https://arxiv.org/) is an open-access repository of scientific
|
|
14
|
+
papers in physics, mathematics, computer science, and related fields, hosted
|
|
15
|
+
by Cornell University.
|
|
16
|
+
|
|
17
|
+
## Usage
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install xv
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
```python
|
|
24
|
+
from imbed_data_prep.arxiv import ... # delegates to xv
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Files in this directory
|
|
28
|
+
|
|
29
|
+
| File | Description |
|
|
30
|
+
|---|---|
|
|
31
|
+
| `__init__.py` | Wrapper module importing from the `xv` package |
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Embeddings of Aggregations
|
|
2
|
+
|
|
3
|
+
Experiments with aggregated embeddings over citation graphs.
|
|
4
|
+
|
|
5
|
+
## Data source
|
|
6
|
+
|
|
7
|
+
This module works with citation graph data and academic paper metadata.
|
|
8
|
+
It takes a set of nodes (papers) with embeddings and citation links, then
|
|
9
|
+
explores how aggregating the titles of cited papers and embedding those
|
|
10
|
+
aggregated strings compares to the original embeddings.
|
|
11
|
+
|
|
12
|
+
The data is expected to come from an external citation graph (e.g. Semantic
|
|
13
|
+
Scholar, OpenAlex, or a custom corpus) loaded as DataFrames with paper IDs,
|
|
14
|
+
titles, and citation edges.
|
|
15
|
+
|
|
16
|
+
## What it does
|
|
17
|
+
|
|
18
|
+
1. **Sample nodes** from a citation graph (nodes that cite nothing are skipped).
|
|
19
|
+
2. **Permute citations** -- for each citing paper, generate multiple random
|
|
20
|
+
orderings of its cited papers.
|
|
21
|
+
3. **Aggregate titles** -- concatenate the citing paper's title followed by the cited-paper titles in each permutation
|
|
22
|
+
order into a single string.
|
|
23
|
+
4. **Embed aggregated strings** -- compute embeddings of these concatenated
|
|
24
|
+
title strings.
|
|
25
|
+
5. **Compare** -- measure how the aggregated embeddings relate to the
|
|
26
|
+
original paper embeddings, exploring whether citation context captures
|
|
27
|
+
similar semantic information.
|
|
28
|
+
|
|
29
|
+
## Output
|
|
30
|
+
|
|
31
|
+
A DataFrame with columns:
|
|
32
|
+
|
|
33
|
+
| Column | Description |
|
|
34
|
+
|---|---|
|
|
35
|
+
| `citing_id` | ID of the citing paper |
|
|
36
|
+
| `n_cited` | Number of papers cited |
|
|
37
|
+
| `permutation_index` | Index of this particular citation ordering |
|
|
38
|
+
| `aggregated_title` | Citing-paper title followed by the cited-paper titles (in permutation order), concatenated |
|
|
39
|
+
| `embedding` | Embedding vector of the aggregated title string (present only when an embedding function is supplied) |
|
|
40
|
+
|
|
41
|
+
## Files in this directory
|
|
42
|
+
|
|
43
|
+
| File | Description |
|
|
44
|
+
|---|---|
|
|
45
|
+
| `__init__.py` | Module code |
|
|
46
|
+
| `embeddings_and_order.ipynb` | Notebook exploring aggregation experiments |
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
"""Tools to analyze the embeddings of aggregations"""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from dol import Pipe
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from typing import List, TypeVar, Tuple
|
|
7
|
+
from collections.abc import Mapping, Callable, Iterable
|
|
8
|
+
import oa
|
|
9
|
+
|
|
10
|
+
# `simple_semantic_features` was moved/renamed in `imbed`; the current equivalent
|
|
11
|
+
# is `three_text_features` in `imbed.components.vectorization`.
|
|
12
|
+
from imbed.components.vectorization import (
|
|
13
|
+
three_text_features as simple_semantic_features,
|
|
14
|
+
)
|
|
15
|
+
from imbed.util import fuzzy_induced_graph as fuzzy_induced_graph, Node, Nodes
|
|
16
|
+
|
|
17
|
+
# Seed used by default where this module draws random permutations
# (e.g. the `seed` default of `get_n_unique_permutations`).
DFLT_RANDOM_SEED = 0

# Default text -> vector function: `three_text_features` from
# `imbed.components.vectorization`, imported above under its legacy name
# `simple_semantic_features`. A previous default was the commented-out
# alternative below.
# DFLT_EMBEDDING_FUNC = oa.embeddings
DFLT_EMBEDDING_FUNC = simple_semantic_features
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_n_unique_permutations(arr, n: int, seed: int = DFLT_RANDOM_SEED):
    """
    Get n unique permutations of an array, with a random seed fixed, and
    raise an error if n is larger than the number of possible permutations.

    Permutations are drawn, with repeats rejected, from a private
    ``np.random.RandomState(seed)``. The global NumPy RNG is neither reseeded
    nor advanced. Results come back in the order they were first drawn, so the
    output order is reproducible for a given seed. This holds even for string
    items, whose hashes (and hence ``set`` iteration order) vary between
    interpreter runs. Items are permuted by position, so each returned tuple
    holds the original objects, not numpy scalars. Any hashable item works,
    including tuples.

    Args:
        arr (Iterable): The items to permute. Items must be hashable.
        n (int): The number of unique permutations to generate.
        seed (int): The random seed for reproducibility.

    Returns:
        list: ``n`` distinct tuples, each a permutation of ``arr``.

    Raises:
        ValueError: If n is larger than the number of possible (distinct)
            permutations. Repeated items reduce that number: ``[1, 1, 2]``
            has only 3 distinct orderings, not ``3! == 6``.

    Examples:

    Which permutations get picked depends on the RNG, so these examples only
    check stable properties (see also test_get_n_unique_permutations).

    >>> perms = get_n_unique_permutations([1, 2, 3], 2)
    >>> len(perms), len(set(perms))
    (2, 2)
    >>> perms == get_n_unique_permutations([1, 2, 3], 2, seed=DFLT_RANDOM_SEED)
    True
    >>> sorted(get_n_unique_permutations([1, 2, 3], 6))
    [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]
    >>> get_n_unique_permutations([1, 2, 3], 7)
    Traceback (most recent call last):
      ...
    ValueError: n (=7) is larger than the number of possible permutations: 6
    >>> get_n_unique_permutations([1, 1, 2], 4)
    Traceback (most recent call last):
      ...
    ValueError: n (=4) is larger than the number of possible permutations: 3
    """
    import math
    from collections import Counter

    items = list(arr)

    # Count *distinct* orderings: len(items)! / prod(multiplicity!).
    # Using len(items)! alone would make the sampling loop below spin forever
    # when `items` contains repeats, since it could never collect n unique tuples.
    n_perms = math.factorial(len(items))
    for multiplicity in Counter(items).values():
        n_perms //= math.factorial(multiplicity)
    if n > n_perms:
        raise ValueError(
            f"n (={n}) is larger than the number of possible permutations: {n_perms}"
        )

    # Shuffling positions with a seeded RandomState consumes the same draws as
    # `np.random.seed(seed); np.random.permutation(items)`, without touching
    # global RNG state. It also never coerces items into a numpy array.
    rng = np.random.RandomState(seed)
    perms = {}  # used as an insertion-ordered set
    while len(perms) < n:
        perm = tuple(items[i] for i in rng.permutation(len(items)))
        perms.setdefault(perm, None)
    return list(perms)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def aggregated_embeddings_for_sample(
    graph: Mapping[Node, Nodes],
    n_nodes: int,
    n_permutations: int,
    *,
    node_to_text: Callable[[Node], str],
    aggregate_texts: callable = "\n\n".join,
    text_to_embedding: callable = None,
    max_permutations: int = 100,
    seed: int = 0,
):
    """
    Sample ``n_nodes`` nodes of ``graph`` without replacement. For each one,
    take up to ``n_permutations`` distinct orderings of its neighbors (e.g.
    its cited ids). For each ordering, aggregate the citing node's text
    followed by the neighbors' texts, and optionally embed the result.

    Args:
        graph (Mapping): Maps each node (e.g. a citing paper id) to its
            neighbors (e.g. the ids it cites).
        n_nodes (int): The number of nodes to sample.
        n_permutations (int): The number of permutations to take for each
            neighbor list. This is capped by ``max_permutations`` and by the
            number of distinct orderings of that list.
        node_to_text (callable): Takes a node (ID) and returns its text.
        aggregate_texts (callable): Takes a list of texts and returns a single
            aggregated text. The list holds the citing node's text first.
        text_to_embedding (callable): Takes a text and returns an embedding.
            If None, no ``embedding`` key is produced.
        max_permutations (int): The maximum number of permutations to take.
        seed (int): Seeds both the node sampling and the permutation draws.

    Yields:
        dict: Each dict has the keys ``citing_id``, ``n_cited`` (number of
        neighbors), ``permutation_index`` and ``aggregated_title``. It also has
        ``embedding`` when ``text_to_embedding`` is given. Nodes without
        neighbors are skipped.

    Raises:
        ValueError: If ``n_nodes`` exceeds the number of nodes in ``graph``.
            This is a generator, so the error surfaces on first iteration.

    Examples:

    >>> graph = {'paper1': ['paper2', 'paper3'], 'paper4': []}
    >>> titles = {'paper1': 'T1', 'paper2': 'T2', 'paper3': 'T3', 'paper4': 'T4'}
    >>> rows = list(aggregated_embeddings_for_sample(
    ...     graph, 2, 5, node_to_text=titles.get, aggregate_texts=' | '.join
    ... ))
    >>> sorted(r['aggregated_title'] for r in rows)
    ['T1 | T2 | T3', 'T1 | T3 | T2']
    >>> sorted(rows[0])
    ['aggregated_title', 'citing_id', 'n_cited', 'permutation_index']
    >>> {r['citing_id'] for r in rows}, {r['n_cited'] for r in rows}
    ({'paper1'}, {2})
    """
    import random
    from collections import Counter
    from math import factorial, prod

    nodes = list(graph.keys())
    if n_nodes > len(nodes):
        raise ValueError(
            f"n_nodes ({n_nodes}) is larger than the number of nodes in the citation_graph ({len(nodes)})"
        )

    # Seeding NumPy does not affect Python's `random` module, so use a
    # dedicated seeded generator to make the node sample reproducible.
    sampled_nodes = random.Random(seed).sample(nodes, n_nodes)

    for citing_id in sampled_nodes:
        neighbor_nodes = graph[citing_id]
        if len(neighbor_nodes) == 0:
            continue  # skip nodes with no citations

        citing_title = node_to_text(citing_id)

        # Repeated neighbors yield fewer than len! distinct orderings. Never
        # request more unique permutations than actually exist.
        n_distinct = factorial(len(neighbor_nodes)) // prod(
            factorial(count) for count in Counter(neighbor_nodes).values()
        )
        n_perms = min(n_permutations, max_permutations, n_distinct)
        perms = get_n_unique_permutations(neighbor_nodes, n_perms, seed=seed)

        for idx, perm in enumerate(perms):
            aggregated_title = aggregate_texts(
                [citing_title] + [node_to_text(neighbor_node) for neighbor_node in perm]
            )

            d = {
                "citing_id": citing_id,
                "n_cited": len(neighbor_nodes),
                "permutation_index": idx,
                "aggregated_title": aggregated_title,
            }
            if text_to_embedding is not None:
                d["embedding"] = text_to_embedding(aggregated_title)
            yield d
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Eager, tabular variant of `aggregated_embeddings_for_sample`. It takes the
# same arguments, collects the yielded dicts and returns them as a DataFrame
# (one row per dict).
get_aggregated_embeddings_for_sample = Pipe(
    aggregated_embeddings_for_sample,  # lazily yields one dict per (node, permutation)
    list,  # materialize the generator
    pd.DataFrame,  # rows from the list of dicts
)
|
|
181
|
+
|
|
182
|
+
# -------------------------------------------------------------------------------------
|
|
183
|
+
# Tests
|
|
184
|
+
|
|
185
|
+
# `simple_semantic_features` was moved/renamed in `imbed`; the current equivalent
|
|
186
|
+
# is `three_text_features` in `imbed.components.vectorization`.
|
|
187
|
+
from imbed.components.vectorization import (
|
|
188
|
+
three_text_features as simple_semantic_features,
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def test_get_n_unique_permutations():
    """Contract check: the RNG's exact picks vary across versions, so check shape only."""
    source = [1, 2, 3]
    wanted = 2
    result = get_n_unique_permutations(source, wanted, seed=0)
    expected_items = sorted(source)

    assert len(result) == wanted, "Wrong number of permutations returned"
    assert len(set(result)) == wanted, "Permutations are not unique"
    every_perm_valid = all(sorted(perm) == expected_items for perm in result)
    assert every_perm_valid, "Each result must be a permutation of the input array"
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_get_n_unique_permutations_error():
    """Asking for more permutations than exist (3! == 6) must raise ValueError."""
    source = [1, 2, 3]
    too_many = 7  # There are only 6 possible permutations
    expected_msg = "n (=7) is larger than the number of possible permutations: 6"

    raised = None
    try:
        get_n_unique_permutations(source, too_many)
    except ValueError as err:
        raised = err

    assert raised is not None, "ValueError was not raised when expected"
    assert str(raised) == expected_msg
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
from collections.abc import Sequence
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _is_vector(v):
    """Return True if ``v`` looks like a non-empty 1-D vector of real numbers.

    Two kinds of input are accepted:

    - Non-string sequences (lists, tuples, ...) whose first item is a real
      number.
    - 1-D numpy arrays whose first item is a real number.

    Only the first item is inspected. ``str``/``bytes``, empty containers,
    non-sequences and multi-dimensional arrays return False.
    """
    import numbers

    if isinstance(v, np.ndarray):
        # ndarrays are not registered as collections.abc.Sequence, so they
        # need their own branch; a vector must have exactly one dimension.
        if v.ndim != 1:
            return False
    elif not isinstance(v, Sequence) or isinstance(v, (str, bytes)):
        # bytes iterate as ints, so exclude them explicitly along with str.
        return False
    first_element = next(iter(v), None)
    # numbers.Real covers int, float (and bool) as well as numpy's integer and
    # floating scalars (e.g. np.float32, which is not a float subclass).
    return isinstance(first_element, numbers.Real)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def test_get_aggregated_embeddings_for_sample():
    """Smoke-test the DataFrame wrapper on a tiny four-paper citation graph."""
    citation_graph = {
        "paper1": ["paper2", "paper3"],
        "paper2": ["paper3"],
        "paper3": ["paper1"],
        "paper4": [],
    }
    titles = {f"paper{i}": f"Title of Paper {i}" for i in range(1, 5)}
    n_nodes, n_permutations = 2, 2

    df = get_aggregated_embeddings_for_sample(
        citation_graph,
        n_nodes,
        n_permutations,
        node_to_text=titles.get,
        text_to_embedding=simple_semantic_features,
    )

    assert not df.empty, "DataFrame is empty"
    max_rows = n_nodes * n_permutations
    assert len(df) <= max_rows, "DataFrame has more rows than expected"
    # Every row's embedding must look like a numeric vector.
    assert all(map(_is_vector, df["embedding"])), "Embeddings are not vectors"
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def test_node_with_no_citations():
    """Sampling every node must still leave out the one that cites nothing."""
    citation_graph = {
        "paper1": ["paper2", "paper3"],
        "paper2": ["paper3"],
        "paper3": ["paper1"],
        "paper4": [],
    }
    titles = {f"paper{i}": f"Title of Paper {i}" for i in range(1, 5)}
    n_nodes = 4  # All nodes
    n_permutations = 2

    df = get_aggregated_embeddings_for_sample(
        citation_graph,
        n_nodes,
        n_permutations,
        node_to_text=titles.get,
        text_to_embedding=simple_semantic_features,
    )

    # 'paper4' has no citations, so it must not appear among the citing ids.
    citing_ids = df["citing_id"].tolist()
    assert "paper4" not in citing_ids, "Node with no citations should be skipped"
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def test_n_nodes_too_large():
    """Requesting more nodes than the graph holds must raise ValueError."""
    citation_graph = {
        "paper1": ["paper2", "paper3"],
        "paper2": ["paper3"],
        "paper3": ["paper1"],
        "paper4": [],
    }
    titles = {f"paper{i}": f"Title of Paper {i}" for i in range(1, 5)}
    n_nodes = 5  # There are only 4 nodes in citation_graph
    n_permutations = 2

    caught = None
    try:
        get_aggregated_embeddings_for_sample(
            citation_graph,
            n_nodes,
            n_permutations,
            node_to_text=titles.get,
            text_to_embedding=simple_semantic_features,
        )
    except ValueError as err:
        caught = err

    assert caught is not None, "ValueError was not raised when expected"
    assert (
        "n_nodes (5) is larger than the number of nodes in the citation_graph (4)"
        in str(caught)
    )
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def test_all_permutation_tools():
    """Run every test of this module, in order, from a single entry point."""
    for check in (
        test_get_n_unique_permutations,
        test_get_n_unique_permutations_error,
        test_get_aggregated_embeddings_for_sample,
        test_node_with_no_citations,
        test_n_nodes_too_large,
    ):
        check()
|