emb-diversity 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. emb_diversity/__init__.py +63 -0
  2. emb_diversity/_accepts_text.py +104 -0
  3. emb_diversity/_registry.py +51 -0
  4. emb_diversity/axes_registry.py +54 -0
  5. emb_diversity/cli.py +212 -0
  6. emb_diversity/compute_pairwise.py +144 -0
  7. emb_diversity/convenience.py +87 -0
  8. emb_diversity/embed.py +45 -0
  9. emb_diversity/embeddings/SBERT.py +38 -0
  10. emb_diversity/embeddings/SimCSE.py +24 -0
  11. emb_diversity/embeddings/__init__.py +0 -0
  12. emb_diversity/embeddings/embed.py +72 -0
  13. emb_diversity/eval/__init__.py +0 -0
  14. emb_diversity/eval/data/STEL.py +76 -0
  15. emb_diversity/eval/data/__init__.py +0 -0
  16. emb_diversity/eval/data/synthstel.py +137 -0
  17. emb_diversity/evaluate_measures.py +260 -0
  18. emb_diversity/measures/__init__.py +0 -0
  19. emb_diversity/measures/_types.py +14 -0
  20. emb_diversity/measures/bins_entropy.py +169 -0
  21. emb_diversity/measures/bottleneck.py +47 -0
  22. emb_diversity/measures/chamfer_dist.py +64 -0
  23. emb_diversity/measures/cluster_inertia.py +59 -0
  24. emb_diversity/measures/convex_hull_volume_2d.py +109 -0
  25. emb_diversity/measures/dcscore.py +100 -0
  26. emb_diversity/measures/diameter.py +37 -0
  27. emb_diversity/measures/dist_dispersion.py +38 -0
  28. emb_diversity/measures/energy.py +50 -0
  29. emb_diversity/measures/graph_entropy.py +81 -0
  30. emb_diversity/measures/hamdiv.py +107 -0
  31. emb_diversity/measures/log_determinant.py +128 -0
  32. emb_diversity/measures/mean_pw_dist.py +37 -0
  33. emb_diversity/measures/mst_dispersion.py +47 -0
  34. emb_diversity/measures/radius.py +52 -0
  35. emb_diversity/measures/renyi_entropy.py +140 -0
  36. emb_diversity/measures/span_centroid.py +55 -0
  37. emb_diversity/measures/span_medoid.py +42 -0
  38. emb_diversity/measures/sum_bottleneck.py +56 -0
  39. emb_diversity/measures/sum_diameter.py +55 -0
  40. emb_diversity/measures/utils.py +27 -0
  41. emb_diversity/measures/vendi_score.py +76 -0
  42. emb_diversity/measures_registry.py +68 -0
  43. emb_diversity/plot/__init__.py +0 -0
  44. emb_diversity/two_d.py +230 -0
  45. emb_diversity/utility/__init__.py +3 -0
  46. emb_diversity/utility/_cache.py +85 -0
  47. emb_diversity/utility/project_root.py +17 -0
  48. emb_diversity-0.0.3.dist-info/METADATA +386 -0
  49. emb_diversity-0.0.3.dist-info/RECORD +53 -0
  50. emb_diversity-0.0.3.dist-info/WHEEL +5 -0
  51. emb_diversity-0.0.3.dist-info/entry_points.txt +2 -0
  52. emb_diversity-0.0.3.dist-info/licenses/LICENSE +21 -0
  53. emb_diversity-0.0.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,63 @@
1
+ """emb-diversity – embedding-based diversity measures for text and vector data."""
2
+
3
+ from __future__ import annotations
4
+
5
+ ### Distance-Based Diversity Measures
6
+ from .measures.mean_pw_dist import mean_pw_dist
7
+ from .measures.dist_dispersion import dist_dispersion
8
+ from .measures.hamdiv import hamdiv
9
+ from .measures.diameter import diameter
10
+ from .measures.bottleneck import bottleneck
11
+ from .measures.sum_bottleneck import sum_bottleneck
12
+ from .measures.sum_diameter import sum_diameter
13
+ from .measures.energy import energy
14
+ from .measures.cluster_inertia import cluster_inertia
15
+ from .measures.span_centroid import span_centroid
16
+ from .measures.chamfer_dist import chamfer_dist
17
+
18
+ ### Volume-Based Diversity Measures
19
+ from .measures.convex_hull_volume_2d import convex_hull_volume_2d
20
+ from .measures.radius import radius
21
+ from .measures.span_medoid import span_medoid
22
+
23
+ ### Distribution-Based Diversity Measures
24
+ from .measures.vendi_score import vendi_score
25
+ from .measures.renyi_entropy import renyi_entropy
26
+ from .measures.dcscore import dcscore
27
+ from .measures.log_determinant import log_determinant
28
+ from .measures.bins_entropy import bins_entropy
29
+
30
+ ### Graph-Based Diversity Measures
31
+ from .measures.graph_entropy import graph_entropy
32
+ from .measures.mst_dispersion import mst_dispersion
33
+
34
+ ### Registries
35
+ from .axes_registry import axes
36
+ from .measures_registry import measures
37
+
38
+ ### Embedding helper
39
+ from .embed import embed_texts
40
+
41
+ ### Main entry point
42
+ from .convenience import measure_diversity
43
+
44
+ ### Caching utilities
45
+ from .compute_pairwise import compute_pairwise_distances, clear_distance_cache, distance_cache_info
46
+
47
+
48
# Names exported by ``from emb_diversity import *``.  One entry per line,
# grouped by role, in the same order as before.
__all__ = [
    # Main entry point
    "measure_diversity",
    # Individual measures
    "mean_pw_dist",
    "dist_dispersion",
    "hamdiv",
    "diameter",
    "bottleneck",
    "sum_bottleneck",
    "sum_diameter",
    "energy",
    "cluster_inertia",
    "span_centroid",
    "chamfer_dist",
    "convex_hull_volume_2d",
    "radius",
    "span_medoid",
    "vendi_score",
    "renyi_entropy",
    "dcscore",
    "log_determinant",
    "bins_entropy",
    "graph_entropy",
    "mst_dispersion",
    # Helpers
    "embed_texts",
    # Registries
    "axes",
    "measures",
    # Pairwise distance caching
    "compute_pairwise_distances",
    "clear_distance_cache",
    "distance_cache_info",
]
@@ -0,0 +1,104 @@
1
+ """Decorator that lets measure functions accept raw text.
2
+
3
+ When a measure function is decorated with ``@accepts_text``, it gains
4
+ the ability to receive a list of strings instead of embeddings. The
5
+ decorator detects text input, embeds it via :func:`emb_diversity.embed.embed_texts`,
6
+ and passes the resulting vectors to the original measure function.
7
+
8
+ Two keyword arguments are added to the decorated function:
9
+
10
+ - ``diversity_axis`` (default ``"semantic"``): which axis to use for embedding
11
+ - ``embedding_model``: explicit model name, overrides the axis
12
+
13
+ If the input is already numeric (e.g. a numpy array), the decorator
14
+ passes it through unchanged.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import functools
20
+ import inspect
21
+ from typing import Sequence
22
+
23
+ from .embed import embed_texts
24
+
25
+
26
+ # ── Input detection ──────────────────────────────────────────────────
27
+
28
+
29
def _is_text_input(data):
    """Report whether *data* should be treated as raw text.

    Only the first element is inspected: a non-empty container whose first
    item is a ``str`` counts as text; an empty container never does.
    """
    if len(data) == 0:
        return False
    return isinstance(data[0], str)
32
+
33
+
34
+ # ── Decorator ────────────────────────────────────────────────────────
35
+
36
+
37
def accepts_text(func):
    """Decorate a measure so that it also accepts a list of strings.

    The returned wrapper takes two extra keyword-only arguments:

    * ``diversity_axis`` – registered axis name (default ``"semantic"``)
    * ``embedding_model`` – explicit model id (overrides axis)

    If the first argument looks like text (see ``_is_text_input``) it is
    embedded with ``embed_texts`` and the embeddings are handed to *func*.
    Any other input is forwarded untouched; in that case the two extra
    keyword arguments are accepted but not used.
    """

    @functools.wraps(func)
    def wrapper(data, *args, diversity_axis="semantic", embedding_model=None, **kwargs):
        # Numeric input: nothing to embed, call straight through.
        if not _is_text_input(data):
            return func(data, *args, **kwargs)

        vectors = embed_texts(
            data,
            diversity_axis=diversity_axis,
            embedding_model=embedding_model,
        )
        return func(vectors, *args, **kwargs)

    # Make help()/IDEs show the measure's own parameters plus the two new ones.
    _patch_signature(wrapper, func)
    return wrapper
63
+
64
+
65
def _patch_signature(wrapper, func):
    """Give *wrapper* a signature that help() and IDEs can display usefully.

    The displayed signature is *func*'s own parameter list with the two
    keyword-only parameters added by the decorator (``diversity_axis`` and
    ``embedding_model``) inserted before ``**kwargs`` when *func* has one,
    or appended otherwise.

    This is purely cosmetic — it only sets ``wrapper.__signature__`` and does
    not change how the wrapper runs.  Because of that it never raises for a
    function it cannot describe:

    * if *func* cannot be introspected, *wrapper* is left untouched;
    * if *func* already declares one of the two parameter names, that
      declaration is kept and no duplicate is added.

    Example for mean_pw_dist:
        Before: mean_pw_dist(data, *args, diversity_axis="semantic", embedding_model=None, **kwargs)
        After:  mean_pw_dist(data, metric="cosine", diversity_axis="semantic", embedding_model=None, **metric_kwargs)

    Args:
        wrapper: The function object whose ``__signature__`` is set.
        func: The original measure function whose parameters are copied.
    """
    try:
        original_sig = inspect.signature(func)
    except (TypeError, ValueError):
        # Not introspectable (e.g. not a callable, or a builtin without a
        # text signature).  The patch is cosmetic, so just skip it.
        return

    original_params = list(original_sig.parameters.values())

    # The two parameters the decorator adds; skip any name func already uses,
    # otherwise Signature.replace() rejects the duplicate.
    added_defaults = (("diversity_axis", "semantic"), ("embedding_model", None))
    new_params = [
        inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
        for name, default in added_defaults
        if name not in original_sig.parameters
    ]

    # A signature has at most one **kwargs parameter and it must come last,
    # so the additions go in between: (data, metric, <new>, **metric_kwargs).
    var_keyword = [
        p for p in original_params if p.kind == inspect.Parameter.VAR_KEYWORD
    ]
    before_kwargs = [
        p for p in original_params if p.kind != inspect.Parameter.VAR_KEYWORD
    ]

    wrapper.__signature__ = original_sig.replace(
        parameters=before_kwargs + new_params + var_keyword
    )
@@ -0,0 +1,51 @@
1
+ """Generic registry pattern for O(1) key-value lookup.
2
+
3
+ Based on https://dev.to/dentedlogic/stop-writing-giant-if-else-chains-master-the-python-registry-pattern-ldm
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
class Registry:
    """Minimal name -> object mapping with a friendly lookup error.

    Shared by the axes registry and the measures registry so neither has to
    repeat the same dictionary plumbing.  Iteration and ``keys()`` follow
    registration order; ``list_all()`` is ordered by key.
    """

    def __init__(self):
        self._store: dict[str, object] = {}

    def register(self, key: str, value: object) -> None:
        """Store *value* under *key*, replacing any previous entry."""
        self._store[key] = value

    def get(self, key: str) -> object:
        """Return the value registered under *key*.

        Raises:
            KeyError: If *key* is unknown; the message lists the registered keys.
        """
        entries = self._store
        if key in entries:
            return entries[key]
        known = ", ".join(sorted(entries)) or "(none)"
        raise KeyError(f"Unknown key {key!r}. Registered: {known}")

    def list_all(self) -> list[object]:
        """Return every registered value, ordered by its key."""
        ordered_keys = sorted(self._store)
        return [self._store[name] for name in ordered_keys]

    def keys(self) -> list[str]:
        """Return the registered keys as a new list."""
        return [*self._store]

    def __contains__(self, key: str) -> bool:
        return key in self._store

    def __len__(self) -> int:
        return len(self._store)

    def __getitem__(self, key: str) -> object:
        # Same contract (and error message) as get().
        return self.get(key)

    def __iter__(self):
        return iter(self._store)
@@ -0,0 +1,54 @@
1
+ """Diversity axes registry.
2
+
3
+ A diversity axis maps a concept (e.g. "semantic", "style") to a default
4
+ embedding model and optional alternatives.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+
11
+ from ._registry import Registry
12
+
13
+
14
@dataclass
class DiversityAxis:
    """One diversity axis: a named concept plus the embedding models for it."""

    # Short identifier, e.g. "semantic".
    name: str
    # Model id used when no explicit model is requested for this axis.
    default_model: str
    # Further model ids suggested for this axis; each instance gets its own list.
    alternative_models: list[str] = field(default_factory=list)
    # Human-readable explanation of the axis.
    description: str = ""
29
+
30
+
31
# Module-level registry instance.  It is created and populated at import
# time; each entry below is a DiversityAxis stored under its own name.
axes = Registry()

# ── Built-in axes ────────────────────────────────────────────────────

# "semantic" axis: this is the axis name used as the default elsewhere in
# the package (e.g. the `diversity_axis="semantic"` keyword default).
axes.register(
    "semantic",
    DiversityAxis(
        name="semantic",
        default_model="all-mpnet-base-v2",
        alternative_models=["all-MiniLM-L6-v2"],
        description="Meaning-based diversity using semantic similarity",
    ),
)

# "style" axis: writing-style embeddings, with several alternative model ids.
axes.register(
    "style",
    DiversityAxis(
        name="style",
        default_model="AnnaWegmann/Style-Embedding",
        alternative_models=["StyleDistance/styledistance", "rrivera1849/LUAR-MUD", "AIDA-UPM/star"],
        description="Writing style diversity",
    ),
)
emb_diversity/cli.py ADDED
@@ -0,0 +1,212 @@
1
+ """Command-line interface for emb-diversity."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import json
7
+ import sys
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import typer
12
+
13
# The Typer application object; every command in this module is attached to
# it via @app.command.  With no arguments the CLI prints its help text
# (no_args_is_help=True).
app = typer.Typer(
    name="emb-diversity",
    help="Measure embedding-based diversity of text data.\n\nRun directly: emb-diversity texts.txt",
    no_args_is_help=True,
)
18
+
19
+
20
# Hidden twin of the `measure` command: same options, same shared
# implementation (_run_measure).  The docstring below is left untouched
# because Typer shows it as the command's help text.
# NOTE(review): no explicit command name is given here, so this is presumably
# exposed as a subcommand named after the function ("default") rather than
# being run for a bare `emb-diversity texts.txt` as the app help suggests —
# confirm against Typer's command-naming rules.
@app.command(hidden=True)
def default(
    input_file: Path = typer.Argument(
        ..., help="Path to input file (.txt, .csv, or .tsv)"
    ),
    # Repeatable option; None when the user passes no -m at all.
    measures: Optional[list[str]] = typer.Option(
        None,
        "--measure",
        "-m",
        help=(
            "Measure to run. Use -m name for one, repeat for several: "
            "-m mean_pw_dist -m diameter. "
            "Use -m variety|balance|disparity for a named set, -m all for all. "
            "Default: graph_entropy."
        ),
    ),
    axis: str = typer.Option(
        "semantic", "--axis", "-a", help="Diversity axis (e.g. semantic, style)."
    ),
    model: Optional[str] = typer.Option(
        None, "--model", help="Explicit embedding model; overrides --axis."
    ),
    column: str = typer.Option(
        "text", "--column", "-c", help="Column name containing text (for CSV/TSV)."
    ),
    # Not validated here: _print_results treats anything other than "json"
    # or "csv" as the table format.
    output_format: str = typer.Option(
        "table", "--format", "-f", help="Output format: table, json, csv."
    ),
) -> None:
    """Measure diversity from a text file or CSV/TSV."""
    _run_measure(input_file, measures, axis, model, column, output_format)
51
+
52
+
53
# Public `measure` subcommand.  It only declares the CLI options and hands
# them to _run_measure; the docstring below is left untouched because Typer
# shows it as the command's help text.
@app.command("measure")
def measure_cmd(
    input_file: Path = typer.Argument(
        ..., help="Path to input file (.txt, .csv, or .tsv)"
    ),
    # Repeatable option; None when the user passes no -m at all.
    measures: Optional[list[str]] = typer.Option(
        None,
        "--measure",
        "-m",
        help=(
            "Measure to run. Use -m name for one, repeat for several: "
            "-m mean_pw_dist -m diameter. "
            "Use -m variety|balance|disparity for a named set, -m all for all. "
            "Default: graph_entropy."
        ),
    ),
    axis: str = typer.Option(
        "semantic", "--axis", "-a", help="Diversity axis (e.g. semantic, style)."
    ),
    model: Optional[str] = typer.Option(
        None, "--model", help="Explicit embedding model; overrides --axis."
    ),
    column: str = typer.Option(
        "text", "--column", "-c", help="Column name containing text (for CSV/TSV)."
    ),
    # Not validated here: _print_results treats anything other than "json"
    # or "csv" as the table format.
    output_format: str = typer.Option(
        "table", "--format", "-f", help="Output format: table, json, csv."
    ),
) -> None:
    """Measure diversity from a text file or CSV/TSV."""
    _run_measure(input_file, measures, axis, model, column, output_format)
84
+
85
+
86
def _run_measure(input_file, measures, axis, model, column, output_format):
    """Shared logic for the measure command.

    Reads the texts, translates Typer's option values into the arguments
    ``measure_diversity()`` expects, runs the measures and prints the result.

    Args:
        input_file: Path of the .txt/.csv/.tsv file to read.
        measures: List of ``-m`` values from Typer, or None if none were given.
        axis: Diversity axis name passed as ``diversity_axis``.
        model: Optional explicit model id passed as ``embedding_model``.
        column: Text column name for CSV/TSV input.
        output_format: Output format handed to ``_print_results``.

    Raises:
        typer.Exit: With code 1 when fewer than two texts were read or when
            a lookup fails with ``KeyError``.
    """
    from .convenience import measure_diversity
    from .measures_registry import MEASURE_SETS

    # ── Read texts ───────────────────────────────────────────────
    texts = _read_texts(input_file, column)
    if len(texts) < 2:
        typer.echo("Error: need at least 2 texts to measure diversity.", err=True)
        raise typer.Exit(code=1)

    # ── Convert Typer's list format to measure_diversity()'s format ──
    # Typer always gives a list (e.g. ["all"]), but measure_diversity()
    # expects a plain string for "all" and named-set shortcuts.
    if measures is None:
        measure_arg = None
    elif len(measures) == 1 and measures[0] in ("all", *MEASURE_SETS):
        measure_arg = measures[0]
    else:
        measure_arg = measures

    # ── Compute ──────────────────────────────────────────────────
    typer.echo(f"Measuring diversity of {len(texts)} texts...", err=True)
    try:
        results = measure_diversity(
            texts, measure=measure_arg, diversity_axis=axis, embedding_model=model,
        )
    except KeyError as exc:
        # str(KeyError) is the repr() of its argument, which would print the
        # message wrapped in an extra pair of quotes; use the raw argument.
        detail = exc.args[0] if exc.args else exc
        # measure_diversity() raises KeyError for unknown measure names.
        # NOTE(review): an unknown --axis presumably surfaces as a KeyError
        # from the axes registry as well, hence both hints — confirm.
        typer.echo(
            f"Error: {detail}. Run 'emb-diversity list-measures' or "
            "'emb-diversity list-axes' to see the available names.",
            err=True,
        )
        raise typer.Exit(code=1)

    # ── Output ───────────────────────────────────────────────────
    _print_results(results, output_format)
121
+
122
+
123
@app.command("list-measures")
def list_measures_cmd() -> None:
    """List all available diversity measures."""
    from .measures_registry import DEFAULT_MEASURE, MEASURE_SETS, measures

    for name in sorted(measures):
        # Tag order: "default" first, then every named set containing the measure.
        # NOTE(review): `in DEFAULT_MEASURE` is a substring test if that
        # constant is a plain string — confirm its type.
        tags = ["default"] if name in DEFAULT_MEASURE else []
        tags += [
            set_name for set_name, members in MEASURE_SETS.items() if name in members
        ]
        suffix = f" [{', '.join(tags)}]" if tags else ""
        typer.echo(f" {name}{suffix}")
137
+
138
+
139
@app.command("list-axes")
def list_axes_cmd() -> None:
    """List registered diversity axes and their models."""
    from .axes_registry import axes

    emit = typer.echo
    for axis in axes.list_all():
        emit(f" {axis.name}")
        emit(f" default model: {axis.default_model}")
        alternatives = axis.alternative_models
        if alternatives:
            emit(f" alternatives: {', '.join(alternatives)}")
        description = axis.description
        if description:
            emit(f" {description}")
        # Blank line between axes.
        emit()
152
+
153
+
154
+ # ── Helpers ──────────────────────────────────────────────────────
155
+
156
+
157
def _read_texts(path: Path, column: str) -> list[str]:
    """Read texts from a file, returning a list of non-empty strings.

    Supported formats:
    - .txt: one text per line, empty lines are skipped
    - .csv: comma-separated, reads the column specified by `column`
    - .tsv: tab-separated, reads the column specified by `column`

    For CSV/TSV, rows with empty (or whitespace-only) values in the text
    column are dropped.  Cell values are taken literally: strings such as
    "NA" or "null" are kept as texts rather than being parsed as missing.

    Files are decoded as UTF-8 (also pandas' default for CSV/TSV).

    Args:
        path: Path to the input file.
        column: Column name to read from CSV/TSV files (default "text").

    Returns:
        List of text strings, stripped of leading/trailing whitespace.

    Raises:
        typer.Exit: With code 1 if the extension is unsupported or the
            column does not exist.
    """
    suffix = path.suffix.lower()

    if suffix == ".txt":
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        stripped = (line.strip() for line in path.read_text(encoding="utf-8").splitlines())
        return [text for text in stripped if text]

    if suffix in (".csv", ".tsv"):
        # Imported here so plain .txt input does not pay for loading pandas.
        import pandas as pd

        separator = "\t" if suffix == ".tsv" else ","
        # keep_default_na=False: empty cells stay "" and words like "NA"
        # stay text instead of turning into NaN (and then the string "nan").
        df = pd.read_csv(path, sep=separator, keep_default_na=False)
        if column not in df.columns:
            typer.echo(
                f"Error: column {column!r} not found. "
                f"Available: {list(df.columns)}",
                err=True,
            )
            raise typer.Exit(code=1)
        cells = df[column].dropna().astype(str).str.strip()
        return [text for text in cells.tolist() if text]

    typer.echo(f"Error: unsupported file extension {suffix!r}. Use .txt, .csv, or .tsv.", err=True)
    raise typer.Exit(code=1)
194
+
195
+
196
def _print_results(results: dict[str, float], fmt: str) -> None:
    """Write *results* in the requested format.

    ``"json"`` prints an indented JSON object, ``"csv"`` writes a
    ``measure,score`` table to stdout, and any other value prints an
    aligned text table with scores at six decimals.
    """
    if fmt == "json":
        typer.echo(json.dumps(results, indent=2))
        return

    if fmt == "csv":
        out = csv.writer(sys.stdout)
        out.writerows([("measure", "score"), *results.items()])
        return

    # table: pad names to the longest one so the scores line up
    width = max(map(len, results), default=0)
    for name, score in results.items():
        typer.echo(f" {name:<{width}} {score:.6f}")
209
+
210
+
211
# Run the Typer app when this module is executed as a script.
if __name__ == "__main__":
    app()
@@ -0,0 +1,144 @@
1
+ """
2
+ Two-level cached pairwise distance computation.
3
+
4
+ scipy.pdist is the bottleneck when several measures are run on the same
5
+ embedding matrix. This module wraps it with:
6
+
7
+ Level 1 — in-process memory: a small bounded LRU dict keyed by full
8
+ content fingerprint of the matrix plus the metric / kwargs. Up to
9
+ _MEMORY_MAX entries are kept; oldest is evicted on overflow.
10
+
11
+ Level 2 — disk: condensed distance array stored under .cache/pdist/
12
+ as safetensors, keyed the same way. Survives across processes — a
13
+ SLURM job that finished yesterday leaves a cache that today's job
14
+ can pick up.
15
+
16
+ Level 3 — compute: scipy.pdist + write through both layers.
17
+
18
+ The cache key folds in the metric and any metric_kwargs, so different
19
+ metrics on the same data do not collide.
20
+ """
21
+ from collections import OrderedDict
22
+ from pathlib import Path
23
+ from typing import Any, Callable, Sequence, Union
24
+
25
+ import numpy as np
26
+ import xxhash
27
+ from scipy.spatial.distance import pdist
28
+ from safetensors.numpy import save_file, load_file
29
+
30
# A distance metric is either a metric name or a callable that takes two
# vectors and returns a float.
DISTANCE_METRIC = Union[str, Callable[[np.ndarray, np.ndarray], float]]
# Default disk-cache directory.  The path is relative, so it resolves against
# the current working directory of the process.
DEFAULT_CACHE_DIR = Path(".cache/pdist")
# how many array elements we feed into the hash function per update() call;
# the flattened array is hashed slice by slice rather than serialised to
# bytes in one piece
_HASH_CHUNK = 1_000_000
# how many distance matrices to keep in memory before evicting the oldest one (LRU)
_MEMORY_MAX = 4

# in-memory cache (LRU): cache key -> condensed distance array
_memory: "OrderedDict[str, np.ndarray]" = OrderedDict()
40
+
41
+
42
def _fingerprint(X: np.ndarray) -> str:
    """Return an xxhash64 hex digest covering shape, dtype and all values of *X*.

    The flattened values are hashed in slices of ``_HASH_CHUNK`` elements so
    that only one slice at a time is serialised to bytes.
    """
    hasher = xxhash.xxh64()
    # Shape and dtype go in first so equal bytes with a different layout differ.
    for header in (X.shape, X.dtype):
        hasher.update(str(header).encode())

    values = X.ravel()
    for offset in range(0, len(values), _HASH_CHUNK):
        piece = values[offset:offset + _HASH_CHUNK]
        hasher.update(piece.tobytes())
    return hasher.hexdigest()
51
+
52
+
53
def _metric_key(metric: DISTANCE_METRIC, metric_kwargs: dict) -> str:
    """Build the metric part of the cache key from *metric* and its kwargs.

    A metric given by name with no kwargs is returned verbatim.  Otherwise
    ``str(metric)`` and the ``name=repr(value)`` pairs (sorted by name) are
    joined with ``|`` and reduced to an xxhash64 hex digest, which is safe
    to embed in a file name.

    NOTE(review): for a callable metric ``str(metric)`` typically contains
    the object's memory address, so the resulting key is presumably not
    stable across processes — confirm before relying on the disk cache with
    callable metrics.
    """
    if isinstance(metric, str) and not metric_kwargs:
        return metric

    kwarg_parts = [f"{name}={metric_kwargs[name]!r}" for name in sorted(metric_kwargs)]
    joined = "|".join([str(metric), *kwarg_parts])
    return xxhash.xxh64(joined.encode()).hexdigest()
61
+
62
+
63
def _store_memory(key: str, result: np.ndarray) -> None:
    """Insert *result* into the in-memory LRU cache under *key*.

    Does nothing when the cache is disabled (``_MEMORY_MAX <= 0``).  If *key*
    is already cached its entry is only marked as most recently used; the
    stored array is not replaced.  Otherwise the oldest entry is evicted
    first when the cache is full.
    """
    if _MEMORY_MAX <= 0:
        return

    if key in _memory:
        _memory.move_to_end(key)
    else:
        if len(_memory) >= _MEMORY_MAX:
            # last=False pops from the "oldest" end of the OrderedDict.
            _memory.popitem(last=False)
        _memory[key] = result
72
+
73
+
74
def compute_pairwise_distances(
    data: Sequence[Sequence[float]],
    metric: DISTANCE_METRIC = "cosine",
    cache_dir: Path = DEFAULT_CACHE_DIR,
    **metric_kwargs: Any,
) -> np.ndarray:
    """
    Compute pairwise distances with two-level (memory + disk) caching.

    Lookup order is memory, then disk, then ``scipy.pdist``; a freshly
    computed result is written to both caches.

    The disk layer is used only when *metric* is given by name.  A callable
    has no identity that is stable across processes (its string form
    typically embeds a memory address), so persisting its results could hand
    a later run distances computed with different code; callable metrics are
    cached in memory only.

    The returned array may be the very object held in the memory cache —
    treat it as read-only.

    Args:
        data: 2D array-like of shape (n_samples, n_features).
        metric: Distance metric name (e.g. "cosine", "euclidean") or callable.
        cache_dir: Root directory for the disk cache (``Path`` or ``str``).
        **metric_kwargs: Extra keyword arguments forwarded to scipy.pdist.
            Included in the cache key so different kwargs do not collide.

    Returns:
        Condensed distance array (upper triangle from scipy.pdist).

    Raises:
        ValueError: If data is empty or single row.
    """
    import os

    X = np.asarray(data, dtype=float)
    n = X.shape[0]
    if n == 0:
        raise ValueError("Cannot compute distances for empty data")
    if n == 1:
        raise ValueError("Cannot compute distances for single data point")

    metric_id = _metric_key(metric, metric_kwargs)
    fp = _fingerprint(X)
    key = f"{fp}|{metric_id}"

    # Level 1: in-memory match by content fingerprint
    if key in _memory:
        _memory.move_to_end(key)
        return _memory[key]

    # Level 2: disk (named metrics only, see docstring)
    path = None
    if isinstance(metric, str):
        # Accept plain strings as well as Path objects.
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        path = cache_dir / f"{fp}_{metric_id}.safetensors"
        if path.exists():
            result = load_file(path)["distances"]
            _store_memory(key, result)
            return result

    # Level 3: compute, populate both layers
    result = pdist(X, metric=metric, **metric_kwargs)
    _store_memory(key, result)
    if path is not None:
        # Write to a per-process temp name and rename into place so that a
        # concurrent reader never sees a half-written cache file.
        tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        save_file({"distances": result}, tmp_path)
        os.replace(tmp_path, path)
    return result
125
+
126
+
127
def clear_distance_cache(cache_dir: Path = DEFAULT_CACHE_DIR) -> None:
    """Clear both memory and disk caches.

    The in-memory cache is always emptied.  The disk cache is removed by
    deleting *cache_dir* and everything in it, if that directory exists.

    Args:
        cache_dir: Root directory of the disk cache (``Path`` or ``str``).
    """
    import shutil

    _memory.clear()
    # Accept plain strings as well as Path objects.
    cache_dir = Path(cache_dir)
    if cache_dir.exists():
        shutil.rmtree(cache_dir)
133
+
134
+
135
def distance_cache_info(cache_dir: Path = DEFAULT_CACHE_DIR) -> dict:
    """Return memory and disk cache statistics.

    Args:
        cache_dir: Root directory of the disk cache (``Path`` or ``str``).

    Returns:
        Dict with ``memory_entries``, ``memory_mb``, ``memory_max``,
        ``disk_files`` and ``disk_mb``.  Sizes are in MiB rounded to two
        decimals; only ``*.safetensors`` files directly inside *cache_dir*
        are counted.
    """
    # Accept plain strings as well as Path objects.
    cache_dir = Path(cache_dir)
    disk_files = list(cache_dir.glob("*.safetensors")) if cache_dir.exists() else []
    return {
        "memory_entries": len(_memory),
        "memory_mb": round(sum(v.nbytes for v in _memory.values()) / 1024 / 1024, 2),
        "memory_max": _MEMORY_MAX,
        "disk_files": len(disk_files),
        "disk_mb": round(sum(f.stat().st_size for f in disk_files) / 1024 / 1024, 2),
    }