emb-diversity 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emb_diversity/__init__.py +63 -0
- emb_diversity/_accepts_text.py +104 -0
- emb_diversity/_registry.py +51 -0
- emb_diversity/axes_registry.py +54 -0
- emb_diversity/cli.py +212 -0
- emb_diversity/compute_pairwise.py +144 -0
- emb_diversity/convenience.py +87 -0
- emb_diversity/embed.py +45 -0
- emb_diversity/embeddings/SBERT.py +38 -0
- emb_diversity/embeddings/SimCSE.py +24 -0
- emb_diversity/embeddings/__init__.py +0 -0
- emb_diversity/embeddings/embed.py +72 -0
- emb_diversity/eval/__init__.py +0 -0
- emb_diversity/eval/data/STEL.py +76 -0
- emb_diversity/eval/data/__init__.py +0 -0
- emb_diversity/eval/data/synthstel.py +137 -0
- emb_diversity/evaluate_measures.py +260 -0
- emb_diversity/measures/__init__.py +0 -0
- emb_diversity/measures/_types.py +14 -0
- emb_diversity/measures/bins_entropy.py +169 -0
- emb_diversity/measures/bottleneck.py +47 -0
- emb_diversity/measures/chamfer_dist.py +64 -0
- emb_diversity/measures/cluster_inertia.py +59 -0
- emb_diversity/measures/convex_hull_volume_2d.py +109 -0
- emb_diversity/measures/dcscore.py +100 -0
- emb_diversity/measures/diameter.py +37 -0
- emb_diversity/measures/dist_dispersion.py +38 -0
- emb_diversity/measures/energy.py +50 -0
- emb_diversity/measures/graph_entropy.py +81 -0
- emb_diversity/measures/hamdiv.py +107 -0
- emb_diversity/measures/log_determinant.py +128 -0
- emb_diversity/measures/mean_pw_dist.py +37 -0
- emb_diversity/measures/mst_dispersion.py +47 -0
- emb_diversity/measures/radius.py +52 -0
- emb_diversity/measures/renyi_entropy.py +140 -0
- emb_diversity/measures/span_centroid.py +55 -0
- emb_diversity/measures/span_medoid.py +42 -0
- emb_diversity/measures/sum_bottleneck.py +56 -0
- emb_diversity/measures/sum_diameter.py +55 -0
- emb_diversity/measures/utils.py +27 -0
- emb_diversity/measures/vendi_score.py +76 -0
- emb_diversity/measures_registry.py +68 -0
- emb_diversity/plot/__init__.py +0 -0
- emb_diversity/two_d.py +230 -0
- emb_diversity/utility/__init__.py +3 -0
- emb_diversity/utility/_cache.py +85 -0
- emb_diversity/utility/project_root.py +17 -0
- emb_diversity-0.0.3.dist-info/METADATA +386 -0
- emb_diversity-0.0.3.dist-info/RECORD +53 -0
- emb_diversity-0.0.3.dist-info/WHEEL +5 -0
- emb_diversity-0.0.3.dist-info/entry_points.txt +2 -0
- emb_diversity-0.0.3.dist-info/licenses/LICENSE +21 -0
- emb_diversity-0.0.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""emb-diversity – embedding-based diversity measures for text and vector data."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
### Distance-Based Diversity Measures
|
|
6
|
+
from .measures.mean_pw_dist import mean_pw_dist
|
|
7
|
+
from .measures.dist_dispersion import dist_dispersion
|
|
8
|
+
from .measures.hamdiv import hamdiv
|
|
9
|
+
from .measures.diameter import diameter
|
|
10
|
+
from .measures.bottleneck import bottleneck
|
|
11
|
+
from .measures.sum_bottleneck import sum_bottleneck
|
|
12
|
+
from .measures.sum_diameter import sum_diameter
|
|
13
|
+
from .measures.energy import energy
|
|
14
|
+
from .measures.cluster_inertia import cluster_inertia
|
|
15
|
+
from .measures.span_centroid import span_centroid
|
|
16
|
+
from .measures.chamfer_dist import chamfer_dist
|
|
17
|
+
|
|
18
|
+
### Volume-Based Diversity Measures
|
|
19
|
+
from .measures.convex_hull_volume_2d import convex_hull_volume_2d
|
|
20
|
+
from .measures.radius import radius
|
|
21
|
+
from .measures.span_medoid import span_medoid
|
|
22
|
+
|
|
23
|
+
### Distribution-Based Diversity Measures
|
|
24
|
+
from .measures.vendi_score import vendi_score
|
|
25
|
+
from .measures.renyi_entropy import renyi_entropy
|
|
26
|
+
from .measures.dcscore import dcscore
|
|
27
|
+
from .measures.log_determinant import log_determinant
|
|
28
|
+
from .measures.bins_entropy import bins_entropy
|
|
29
|
+
|
|
30
|
+
### Graph-Based Diversity Measures
|
|
31
|
+
from .measures.graph_entropy import graph_entropy
|
|
32
|
+
from .measures.mst_dispersion import mst_dispersion
|
|
33
|
+
|
|
34
|
+
### Registries
|
|
35
|
+
from .axes_registry import axes
|
|
36
|
+
from .measures_registry import measures
|
|
37
|
+
|
|
38
|
+
### Embedding helper
|
|
39
|
+
from .embed import embed_texts
|
|
40
|
+
|
|
41
|
+
### Main entry point
|
|
42
|
+
from .convenience import measure_diversity
|
|
43
|
+
|
|
44
|
+
### Caching utilities
|
|
45
|
+
from .compute_pairwise import compute_pairwise_distances, clear_distance_cache, distance_cache_info
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
__all__ = [
|
|
49
|
+
# Main entry point
|
|
50
|
+
"measure_diversity",
|
|
51
|
+
# Individual measures
|
|
52
|
+
"mean_pw_dist", "dist_dispersion", "hamdiv", "diameter", "bottleneck",
|
|
53
|
+
"sum_bottleneck", "sum_diameter", "energy", "cluster_inertia", "span_centroid", "chamfer_dist",
|
|
54
|
+
"convex_hull_volume_2d", "radius", "span_medoid", "vendi_score",
|
|
55
|
+
"renyi_entropy", "dcscore", "log_determinant", "bins_entropy",
|
|
56
|
+
"graph_entropy", "mst_dispersion",
|
|
57
|
+
# Helpers
|
|
58
|
+
"embed_texts",
|
|
59
|
+
# Registries
|
|
60
|
+
"axes", "measures",
|
|
61
|
+
# Pairwise distance caching
|
|
62
|
+
"compute_pairwise_distances", "clear_distance_cache", "distance_cache_info",
|
|
63
|
+
]
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Decorator that lets measure functions accept raw text.
|
|
2
|
+
|
|
3
|
+
When a measure function is decorated with ``@accepts_text``, it gains
|
|
4
|
+
the ability to receive a list of strings instead of embeddings. The
|
|
5
|
+
decorator detects text input, embeds it via :func:`emb_diversity.embed.embed_texts`,
|
|
6
|
+
and passes the resulting vectors to the original measure function.
|
|
7
|
+
|
|
8
|
+
Two keyword arguments are added to the decorated function:
|
|
9
|
+
|
|
10
|
+
- ``diversity_axis`` (default ``"semantic"``): which axis to use for embedding
|
|
11
|
+
- ``embedding_model``: explicit model name, overrides the axis
|
|
12
|
+
|
|
13
|
+
If the input is already numeric (e.g. a numpy array), the decorator
|
|
14
|
+
passes it through unchanged.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import functools
|
|
20
|
+
import inspect
|
|
21
|
+
from typing import Sequence
|
|
22
|
+
|
|
23
|
+
from .embed import embed_texts
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# ── Input detection ──────────────────────────────────────────────────
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_text_input(data):
    """Tell whether *data* should be treated as raw text.

    Only the first element is inspected: a non-empty container whose first
    item is a ``str`` counts as text, an empty container never does.
    """
    if len(data) == 0:
        return False
    first_item = data[0]
    return isinstance(first_item, str)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ── Decorator ────────────────────────────────────────────────────────
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def accepts_text(func):
    """Decorate a measure function so that it also accepts a list of strings.

    Two keyword-only options are added to the decorated callable:

    * ``diversity_axis`` – registered axis name (default ``"semantic"``)
    * ``embedding_model`` – explicit model id (overrides axis)

    If the first positional argument looks like text (see
    ``_is_text_input``), it is embedded with ``embed_texts`` and the
    resulting vectors are handed to *func*. Any other input is forwarded
    untouched. The two extra options are consumed by the wrapper and are
    never passed on to *func*.
    """

    @functools.wraps(func)
    def wrapper(data, *args, diversity_axis="semantic", embedding_model=None, **kwargs):
        if not _is_text_input(data):
            # Already numeric: nothing to embed.
            return func(data, *args, **kwargs)
        vectors = embed_texts(
            data,
            diversity_axis=diversity_axis,
            embedding_model=embedding_model,
        )
        return func(vectors, *args, **kwargs)

    # Advertise func's real parameters (plus the two new ones) to help()/IDEs.
    _patch_signature(wrapper, func)
    return wrapper
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _patch_signature(wrapper, func):
    """Give *wrapper* a signature that mirrors *func* plus the decorator's options.

    Only ``wrapper.__signature__`` is assigned, so this changes what help()
    and IDEs display, not how the wrapper is called.

    The keyword-only parameters ``diversity_axis`` (default ``"semantic"``)
    and ``embedding_model`` (default ``None``) are placed after *func*'s own
    parameters, but ahead of its ``**kwargs`` catch-all when it has one.

    Example for mean_pw_dist:
        Before: mean_pw_dist(data, *args, diversity_axis="semantic", embedding_model=None, **kwargs)
        After:  mean_pw_dist(data, metric="cosine", diversity_axis="semantic", embedding_model=None, **metric_kwargs)
    """
    keyword_only = inspect.Parameter.KEYWORD_ONLY
    added = [
        inspect.Parameter("diversity_axis", keyword_only, default="semantic"),
        inspect.Parameter("embedding_model", keyword_only, default=None),
    ]

    signature = inspect.signature(func)
    own = list(signature.parameters.values())

    # Python only allows a single **kwargs parameter and it must come last,
    # so checking the final parameter is enough to find it.
    if own and own[-1].kind == inspect.Parameter.VAR_KEYWORD:
        merged = own[:-1] + added + [own[-1]]
    else:
        merged = own + added

    wrapper.__signature__ = signature.replace(parameters=merged)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Generic registry pattern for O(1) key-value lookup.
|
|
2
|
+
|
|
3
|
+
Based on https://dev.to/dentedlogic/stop-writing-giant-if-else-chains-master-the-python-registry-pattern-ldm
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
class Registry:
    """Small name-to-object mapping with register / get / list helpers.

    Wraps a plain ``dict`` so the axes registry and the measures registry
    share one implementation. Iteration and ``keys()`` follow insertion
    order; ``list_all()`` is sorted by key.
    """

    def __init__(self):
        self._store: dict[str, object] = {}

    def register(self, key: str, value: object) -> None:
        """Store *value* under *key*, replacing any previous entry."""
        self._store[key] = value

    def get(self, key: str) -> object:
        """Return the value registered under *key*.

        Raises:
            KeyError: If *key* is unknown. The message lists the registered
                keys in sorted order.
        """
        if key in self._store:
            return self._store[key]
        registered = ", ".join(sorted(self._store)) or "(none)"
        raise KeyError(f"Unknown key {key!r}. Registered: {registered}")

    def list_all(self) -> list[object]:
        """Return every registered value, ordered by its key."""
        ordered_keys = sorted(self._store)
        return list(map(self._store.__getitem__, ordered_keys))

    def keys(self) -> list[str]:
        """Return the registered keys in insertion order."""
        return list(self._store)

    def __contains__(self, key: str) -> bool:
        return key in self._store

    def __len__(self) -> int:
        return len(self._store)

    def __getitem__(self, key: str) -> object:
        # Same lookup (and same KeyError message) as get().
        return self.get(key)

    def __iter__(self):
        return iter(self._store.keys())
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Diversity axes registry.
|
|
2
|
+
|
|
3
|
+
A diversity axis maps a concept (e.g. "semantic", "style") to a default
|
|
4
|
+
embedding model and optional alternatives.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
|
|
11
|
+
from ._registry import Registry
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class DiversityAxis:
    """Settings describing one diversity axis.

    Attributes:
        name: Short identifier (e.g. ``"semantic"``).
        default_model: HuggingFace model id used when no model is given.
        alternative_models: Other models that work well for this axis.
        description: Human-readable explanation shown in docs and CLI.
    """

    name: str
    default_model: str
    # default_factory gives every instance its own list.
    alternative_models: list[str] = field(default_factory=list)
    description: str = field(default="")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Module-level registry instance
axes = Registry()

# ── Built-in axes ────────────────────────────────────────────────────
# Each built-in axis is registered under its own ``name``.

for _builtin_axis in (
    DiversityAxis(
        name="semantic",
        default_model="all-mpnet-base-v2",
        alternative_models=["all-MiniLM-L6-v2"],
        description="Meaning-based diversity using semantic similarity",
    ),
    DiversityAxis(
        name="style",
        default_model="AnnaWegmann/Style-Embedding",
        alternative_models=["StyleDistance/styledistance", "rrivera1849/LUAR-MUD", "AIDA-UPM/star"],
        description="Writing style diversity",
    ),
):
    axes.register(_builtin_axis.name, _builtin_axis)
del _builtin_axis  # keep the loop variable out of the module namespace
|
emb_diversity/cli.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Command-line interface for emb-diversity."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import csv
|
|
6
|
+
import json
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
import typer
|
|
12
|
+
|
|
13
|
+
app = typer.Typer(
|
|
14
|
+
name="emb-diversity",
|
|
15
|
+
help="Measure embedding-based diversity of text data.\n\nRun directly: emb-diversity texts.txt",
|
|
16
|
+
no_args_is_help=True,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Hidden twin of the ``measure`` command: same parameters, same behaviour.
# NOTE(review): this is registered as a hidden sub-command; whether
# ``emb-diversity texts.txt`` (as the app help suggests) is actually routed
# here depends on how Typer dispatches commands — confirm.
@app.command(hidden=True)
def default(
    input_file: Path = typer.Argument(..., help="Path to input file (.txt, .csv, or .tsv)"),
    measures: Optional[list[str]] = typer.Option(
        None,
        "--measure",
        "-m",
        help=(
            "Measure to run. Use -m name for one, repeat for several: -m mean_pw_dist -m diameter. "
            "Use -m variety|balance|disparity for a named set, -m all for all. Default: graph_entropy."
        ),
    ),
    axis: str = typer.Option("semantic", "--axis", "-a", help="Diversity axis (e.g. semantic, style)."),
    model: Optional[str] = typer.Option(None, "--model", help="Explicit embedding model; overrides --axis."),
    column: str = typer.Option("text", "--column", "-c", help="Column name containing text (for CSV/TSV)."),
    output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, csv."),
) -> None:
    """Measure diversity from a text file or CSV/TSV."""
    # All real work lives in the shared helper.
    _run_measure(
        input_file=input_file,
        measures=measures,
        axis=axis,
        model=model,
        column=column,
        output_format=output_format,
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Explicit ``measure`` sub-command; parameters are parsed by Typer and handed
# unchanged to the shared helper.
@app.command("measure")
def measure_cmd(
    input_file: Path = typer.Argument(..., help="Path to input file (.txt, .csv, or .tsv)"),
    measures: Optional[list[str]] = typer.Option(
        None,
        "--measure",
        "-m",
        help=(
            "Measure to run. Use -m name for one, repeat for several: -m mean_pw_dist -m diameter. "
            "Use -m variety|balance|disparity for a named set, -m all for all. Default: graph_entropy."
        ),
    ),
    axis: str = typer.Option("semantic", "--axis", "-a", help="Diversity axis (e.g. semantic, style)."),
    model: Optional[str] = typer.Option(None, "--model", help="Explicit embedding model; overrides --axis."),
    column: str = typer.Option("text", "--column", "-c", help="Column name containing text (for CSV/TSV)."),
    output_format: str = typer.Option("table", "--format", "-f", help="Output format: table, json, csv."),
) -> None:
    """Measure diversity from a text file or CSV/TSV."""
    _run_measure(
        input_file=input_file,
        measures=measures,
        axis=axis,
        model=model,
        column=column,
        output_format=output_format,
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _run_measure(input_file, measures, axis, model, column, output_format):
    """Read the input texts, compute the requested measures and print them.

    Shared implementation behind the ``default`` and ``measure`` commands.

    Args:
        input_file: Path of the .txt/.csv/.tsv file to read.
        measures: ``None`` or the list of names collected from ``--measure``.
        axis: Diversity axis name forwarded as ``diversity_axis``.
        model: Optional explicit model id forwarded as ``embedding_model``.
        column: Column holding the text in CSV/TSV input.
        output_format: ``"json"``, ``"csv"`` or anything else for the table.

    Raises:
        typer.Exit: With code 1 when fewer than two texts were read or when
            ``measure_diversity`` raises ``KeyError``.
    """
    from .convenience import measure_diversity
    from .measures_registry import MEASURE_SETS

    # ── Read texts ───────────────────────────────────────────────
    texts = _read_texts(input_file, column)
    if len(texts) < 2:
        typer.echo("Error: need at least 2 texts to measure diversity.", err=True)
        raise typer.Exit(code=1)

    # ── Convert Typer's list format to measure_diversity()'s format ──
    # Typer always gives a list (e.g. ["all"]), but measure_diversity()
    # expects a plain string for "all" and named-set shortcuts.
    if measures is None:
        measure_arg = None
    elif len(measures) == 1 and measures[0] in ("all", *MEASURE_SETS):
        measure_arg = measures[0]
    else:
        measure_arg = measures

    # ── Compute ──────────────────────────────────────────────────
    typer.echo(f"Measuring diversity of {len(texts)} texts...", err=True)
    try:
        results = measure_diversity(
            texts, measure=measure_arg, diversity_axis=axis, embedding_model=model,
        )
    except KeyError as exc:
        # measure_diversity() raises KeyError for unknown measure names.
        # str(KeyError) is the repr of its argument (extra quotes, escaped
        # inner quotes), so print the raw message instead.
        reason = exc.args[0] if exc.args else exc
        typer.echo(f"Error: {reason}. Run 'emb-diversity list-measures' to see available measures.", err=True)
        raise typer.Exit(code=1) from exc

    # ── Output ───────────────────────────────────────────────────
    _print_results(results, output_format)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@app.command("list-measures")
def list_measures_cmd() -> None:
    """List all available diversity measures."""
    from .measures_registry import DEFAULT_MEASURE, MEASURE_SETS, measures

    # NOTE(review): DEFAULT_MEASURE's type is not visible from this module.
    # If it is a single name (a str), ``name in DEFAULT_MEASURE`` would be a
    # substring test, so normalise it to a collection of names first.
    if isinstance(DEFAULT_MEASURE, str):
        default_names = (DEFAULT_MEASURE,)
    else:
        default_names = DEFAULT_MEASURE

    for name in sorted(measures):
        tags = []
        if name in default_names:
            tags.append("default")
        # Append the name of every measure set this measure belongs to.
        tags.extend(set_name for set_name, members in MEASURE_SETS.items() if name in members)
        suffix = f" [{', '.join(tags)}]" if tags else ""
        typer.echo(f" {name}{suffix}")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@app.command("list-axes")
def list_axes_cmd() -> None:
    """List registered diversity axes and their models."""
    from .axes_registry import axes

    # One paragraph per axis: name, default model, then the optional
    # alternatives and description, followed by an empty line.
    for axis in axes.list_all():
        paragraph = [f" {axis.name}", f" default model: {axis.default_model}"]
        if axis.alternative_models:
            paragraph.append(f" alternatives: {', '.join(axis.alternative_models)}")
        if axis.description:
            paragraph.append(f" {axis.description}")
        for line in paragraph:
            typer.echo(line)
        typer.echo()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# ── Helpers ──────────────────────────────────────────────────────
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _read_texts(path: Path, column: str) -> list[str]:
    """Read texts from a file, returning a list of non-empty strings.

    Supported formats:
    - .txt: one text per line, empty lines are skipped
    - .csv: comma-separated, reads the column specified by `column`
    - .tsv: tab-separated, reads the column specified by `column`

    For CSV/TSV, rows with empty values in the text column are dropped.

    Args:
        path: Path to the input file.
        column: Column name to read from CSV/TSV files (default "text").

    Returns:
        List of text strings, stripped of leading/trailing whitespace.

    Raises:
        typer.Exit: With code 1 if the column is missing or the file
            extension is not supported.
    """
    suffix = path.suffix.lower()

    if suffix == ".txt":
        # Explicit UTF-8 so the result does not depend on the platform
        # locale (pandas.read_csv below also defaults to UTF-8).
        stripped_lines = (line.strip() for line in path.read_text(encoding="utf-8").splitlines())
        return [text for text in stripped_lines if text]

    if suffix in (".csv", ".tsv"):
        # Imported here so plain .txt input does not pay for loading pandas.
        import pandas as pd

        separator = "\t" if suffix == ".tsv" else ","
        # keep_default_na=False: texts such as "NA" or "null" stay texts
        # instead of being parsed as missing values.
        df = pd.read_csv(path, sep=separator, keep_default_na=False)
        if column not in df.columns:
            typer.echo(
                f"Error: column {column!r} not found. "
                f"Available: {list(df.columns)}",
                err=True,
            )
            raise typer.Exit(code=1)
        # Drop missing cells *before* the str conversion; otherwise they
        # would turn into the literal text "nan". Then strip and discard
        # cells that are empty or whitespace-only.
        cells = df[column].dropna().astype(str).str.strip()
        return [text for text in cells.tolist() if text]

    typer.echo(f"Error: unsupported file extension {suffix!r}. Use .txt, .csv, or .tsv.", err=True)
    raise typer.Exit(code=1)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _print_results(results: dict[str, float], fmt: str) -> None:
    """Write *results* to stdout as JSON, CSV, or an aligned table.

    Args:
        results: Mapping of measure name to score.
        fmt: ``"json"`` or ``"csv"``; any other value produces the table.
    """
    if fmt == "json":
        typer.echo(json.dumps(results, indent=2))
        return

    if fmt == "csv":
        # Header row followed by one (measure, score) row per result.
        csv.writer(sys.stdout).writerows([("measure", "score"), *results.items()])
        return

    # Table: left-align names to the width of the longest one.
    width = max(map(len, results), default=0)
    for label, value in results.items():
        typer.echo(f" {label:<{width}} {value:.6f}")
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
if __name__ == "__main__":
|
|
212
|
+
app()
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Two-level cached pairwise distance computation.
|
|
3
|
+
|
|
4
|
+
scipy.pdist is the bottleneck when several measures are run on the same
|
|
5
|
+
embedding matrix. This module wraps it with:
|
|
6
|
+
|
|
7
|
+
Level 1 — in-process memory: a small bounded LRU dict keyed by full
|
|
8
|
+
content fingerprint of the matrix plus the metric / kwargs. Up to
|
|
9
|
+
_MEMORY_MAX entries are kept; oldest is evicted on overflow.
|
|
10
|
+
|
|
11
|
+
Level 2 — disk: condensed distance array stored under .cache/pdist/
|
|
12
|
+
as safetensors, keyed the same way. Survives across processes — a
|
|
13
|
+
SLURM job that finished yesterday leaves a cache that today's job
|
|
14
|
+
can pick up.
|
|
15
|
+
|
|
16
|
+
Level 3 — compute: scipy.pdist + write through both layers.
|
|
17
|
+
|
|
18
|
+
The cache key folds in the metric and any metric_kwargs, so different
|
|
19
|
+
metrics on the same data do not collide.
|
|
20
|
+
"""
|
|
21
|
+
from collections import OrderedDict
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any, Callable, Sequence, Union
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
import xxhash
|
|
27
|
+
from scipy.spatial.distance import pdist
|
|
28
|
+
from safetensors.numpy import save_file, load_file
|
|
29
|
+
|
|
30
|
+
DISTANCE_METRIC = Union[str, Callable[[np.ndarray, np.ndarray], float]]
|
|
31
|
+
DEFAULT_CACHE_DIR = Path(".cache/pdist")
|
|
32
|
+
# how many array elements we feed into the hash function per update, to keep memory
|
|
33
|
+
# usage constant regardless of input size
|
|
34
|
+
_HASH_CHUNK = 1_000_000
|
|
35
|
+
# how many distance matrices to keep in memory before evicting the oldest one (LRU)
|
|
36
|
+
_MEMORY_MAX = 4
|
|
37
|
+
|
|
38
|
+
# in-memory cache (LRU)
|
|
39
|
+
_memory: "OrderedDict[str, np.ndarray]" = OrderedDict()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _fingerprint(X: np.ndarray) -> str:
    """Return the xxh64 hex digest of *X*'s shape, dtype and full contents.

    The flattened data is fed to the hasher in slices of ``_HASH_CHUNK``
    elements, so only one slice at a time is serialised to bytes.
    """
    hasher = xxhash.xxh64()
    # Shape and dtype go in first, ahead of the raw element bytes.
    for header in (X.shape, X.dtype):
        hasher.update(str(header).encode())

    values = X.ravel()
    total = len(values)
    start = 0
    while start < total:
        stop = start + _HASH_CHUNK
        hasher.update(values[start:stop].tobytes())
        start = stop
    return hasher.hexdigest()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _metric_key(metric: DISTANCE_METRIC, metric_kwargs: dict) -> str:
    """Build the cache-key component for a metric and its keyword arguments.

    A plain metric name without kwargs is returned as-is. Otherwise the
    metric and the kwargs (sorted by name) are joined into one string and
    the xxh64 hex digest of that string is returned.

    NumPy array kwargs are represented by a digest of their full contents
    (plus shape and dtype) rather than by ``repr()``: NumPy abbreviates the
    repr of large arrays with "...", so two different arrays could otherwise
    produce the same key and share a cache entry. Keys for kwargs that
    contain no arrays are unchanged by this.

    Args:
        metric: Metric name or callable.
        metric_kwargs: Extra keyword arguments that will be passed to pdist.

    Returns:
        The metric name, or a hex digest string.
    """
    if not metric_kwargs and isinstance(metric, str):
        return metric

    # NOTE(review): for a callable, str(metric) is usually the default repr,
    # which embeds a memory address — such keys are presumably not stable
    # across processes. Confirm whether callables should hit the disk cache.
    parts = [str(metric)]
    for name in sorted(metric_kwargs):
        value = metric_kwargs[name]
        if isinstance(value, np.ndarray) and not value.dtype.hasobject:
            content = xxhash.xxh64(value.tobytes()).hexdigest()
            rendered = f"ndarray(shape={value.shape}, dtype={value.dtype}, xxh64={content})"
        else:
            rendered = repr(value)
        parts.append(f"{name}={rendered}")
    return xxhash.xxh64("|".join(parts).encode()).hexdigest()
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _store_memory(key: str, result: np.ndarray) -> None:
    """Insert *result* into the in-memory LRU cache under *key*.

    Does nothing when the cache is disabled (``_MEMORY_MAX <= 0``). If *key*
    is already cached the entry is only marked as most recently used; the
    stored array is not replaced. Otherwise the oldest entry is evicted
    first when the cache is full.
    """
    if _MEMORY_MAX <= 0:
        return
    if key not in _memory:
        if len(_memory) >= _MEMORY_MAX:
            _memory.popitem(last=False)  # drop the least recently used entry
        _memory[key] = result
    else:
        _memory.move_to_end(key)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def compute_pairwise_distances(
    data: Sequence[Sequence[float]],
    metric: DISTANCE_METRIC = "cosine",
    cache_dir: Path = DEFAULT_CACHE_DIR,
    **metric_kwargs: Any,
) -> np.ndarray:
    """
    Compute pairwise distances with two-level (memory + disk) caching.

    Lookup order: in-memory LRU, then ``cache_dir``, then scipy.pdist. A
    freshly computed result is written to both cache layers.

    Args:
        data: 2D array-like of shape (n_samples, n_features).
        metric: Distance metric name (e.g. "cosine", "euclidean") or callable.
        cache_dir: Root directory for the disk cache (``Path`` or str);
            created if it does not exist.
        **metric_kwargs: Extra keyword arguments forwarded to scipy.pdist.
            Included in the cache key so different kwargs do not collide.

    Returns:
        Condensed distance array (upper triangle from scipy.pdist). On a
        memory-cache hit this is the cached array object itself, so callers
        should not modify it in place.

    Raises:
        ValueError: If data is empty or single row.
    """
    import os
    import uuid

    X = np.asarray(data, dtype=float)
    n = X.shape[0]
    if n == 0:
        raise ValueError("Cannot compute distances for empty data")
    if n == 1:
        raise ValueError("Cannot compute distances for single data point")

    metric_id = _metric_key(metric, metric_kwargs)
    fp = _fingerprint(X)
    key = f"{fp}|{metric_id}"

    # Level 1: in-memory match by content fingerprint
    if key in _memory:
        _memory.move_to_end(key)
        return _memory[key]

    # Level 2: disk
    cache_dir = Path(cache_dir)  # also accept a plain string path
    cache_dir.mkdir(parents=True, exist_ok=True)
    path = cache_dir / f"{fp}_{metric_id}.safetensors"
    if path.exists():
        result = load_file(path)["distances"]
        _store_memory(key, result)
        return result

    # Level 3: compute, populate both layers
    result = pdist(X, metric=metric, **metric_kwargs)
    _store_memory(key, result)

    # Write to a uniquely named temp file and rename it into place, so a
    # concurrent process never finds a half-written cache file at `path`.
    # The ".tmp" suffix keeps it out of the "*.safetensors" glob.
    tmp_path = path.with_name(f"{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        save_file({"distances": result}, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial file; the original error is
        # what gets re-raised.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
    return result
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def clear_distance_cache(cache_dir: Path = DEFAULT_CACHE_DIR) -> None:
    """Empty the in-memory cache and delete *cache_dir* with all its contents.

    Args:
        cache_dir: Disk cache directory to remove; nothing happens on disk
            if it does not exist.
    """
    import shutil

    _memory.clear()
    if not cache_dir.exists():
        return
    shutil.rmtree(cache_dir)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def distance_cache_info(cache_dir: Path = DEFAULT_CACHE_DIR) -> dict:
    """Report the size of the in-memory and on-disk distance caches.

    Args:
        cache_dir: Disk cache directory to inspect.

    Returns:
        Dict with ``memory_entries``, ``memory_mb``, ``memory_max``,
        ``disk_files`` and ``disk_mb``; the ``*_mb`` values are rounded to
        two decimals. Only ``*.safetensors`` files directly inside
        *cache_dir* are counted.
    """

    def _to_mb(n_bytes):
        return round(n_bytes / 1024 / 1024, 2)

    if cache_dir.exists():
        on_disk = list(cache_dir.glob("*.safetensors"))
    else:
        on_disk = []

    memory_bytes = sum(arr.nbytes for arr in _memory.values())
    info = {
        "memory_entries": len(_memory),
        "memory_mb": _to_mb(memory_bytes),
        "memory_max": _MEMORY_MAX,
        "disk_files": len(on_disk),
    }
    info["disk_mb"] = _to_mb(sum(entry.stat().st_size for entry in on_disk))
    return info
|