kgmodule-utils 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kgmodule-utils
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Shared types and snapshot infrastructure for the KGModule SDK
5
5
  License: Elastic-2.0
6
6
  License-File: LICENSE
@@ -10,7 +10,7 @@ build-backend = "poetry.core.masonry.api"
10
10
 
11
11
  [project]
12
12
  name = "kgmodule-utils"
13
- version = "0.2.0"
13
+ version = "0.2.2"
14
14
  description = "Shared types and snapshot infrastructure for the KGModule SDK"
15
15
  readme = "README.md"
16
16
  license = { text = "Elastic-2.0" }
@@ -5,6 +5,8 @@ Sub-packages:
5
5
  kg_utils.snapshots — Snapshot, SnapshotManager, SnapshotManifest, etc.
6
6
  kg_utils.embed — Embedder protocol, DEFAULT_MODEL, KNOWN_MODELS,
7
7
  kg_model_cache_dir(), resolve_model_path().
8
+ kg_utils.embedder — Concrete SentenceTransformerEmbedder, get_embedder(),
9
+ wrap_embedder(), load_sentence_transformer().
8
10
  """
9
11
 
10
- __version__ = "0.1.0"
12
+ __version__ = "0.2.2"
@@ -0,0 +1,233 @@
1
+ """kg_utils.embedder — Concrete SentenceTransformer embedding for the KGModule stack.
2
+
3
+ All model-loading logic lives here so that the ``local_files_only`` guard,
4
+ KNOWN_MODELS alias resolution, and path convention are defined exactly once.
5
+ Every KG module (doc_kg, diary_kg, code_kg, …) imports from here instead of
6
+ reimplementing the load sequence.
7
+
8
+ Contents
9
+ --------
10
+ Embedder
11
+ Abstract base class with ``embed_texts`` + ``embed_query`` + ``dim``.
12
+
13
+ SentenceTransformerEmbedder
14
+ Concrete implementation. Always uses ``local_files_only=True`` when the
15
+ model is cached locally — prevents HuggingFace HEAD requests that leave
16
+ stale thread/network state and cause SIGBUS on MPS.
17
+
18
+ load_sentence_transformer(model_name)
19
+ Raw ``SentenceTransformer`` factory with the canonical safe-load sequence.
20
+ Use when you need the bare model object (e.g. multi-process workers that
21
+ each load their own copy by name).
22
+
23
+ get_embedder(model_name)
24
+ High-level factory returning a ready-to-use ``SentenceTransformerEmbedder``.
25
+
26
+ wrap_embedder(st_model, model_name)
27
+ Wrap an already-loaded ``SentenceTransformer`` as an ``Embedder``. Use
28
+ this to share a live model between pipeline stages (e.g. DiaryTransformer
29
+ → DocKG) without loading a second copy on MPS/CUDA.
30
+
31
+ Author: Eric G. Suchanek, PhD
32
+ License: Elastic 2.0
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ import os
38
+ from typing import Any
39
+
40
+ from kg_utils.embed import DEFAULT_MODEL, KNOWN_MODELS, resolve_model_path
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Abstract base
45
+ # ---------------------------------------------------------------------------
46
+
47
+
48
class Embedder:
    """Common interface implemented by every KGModule embedding backend.

    A backend must override :meth:`embed_texts` and assign ``dim``, the
    length of the vectors it returns. :meth:`embed_query` is provided in
    terms of :meth:`embed_texts` and may be overridden by subclasses.
    """

    # Length of every vector this backend emits; assigned by concrete subclasses.
    dim: int

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Turn each string in *texts* into one float32 embedding vector.

        :param texts: Strings to embed.
        :return: One vector per input string.
        :raises NotImplementedError: Always, on this base class — subclasses override it.
        """
        raise NotImplementedError

    def embed_query(self, query: str) -> list[float]:
        """Embed a single *query* string by running it as a batch of one.

        :param query: The string to embed.
        :return: The vector produced for *query*.
        """
        single_batch = self.embed_texts([query])
        return single_batch[0]
71
+
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # Canonical model loader
75
+ # ---------------------------------------------------------------------------
76
+
77
+
78
def load_sentence_transformer(model_name: str = DEFAULT_MODEL) -> Any:
    """Build a ``SentenceTransformer`` while avoiding network access when possible.

    Lookup order:

    1. Map a KNOWN_MODELS alias to its HuggingFace repo ID (names that are
       not aliases pass through unchanged).
    2. If ``resolve_model_path()`` points at a path that exists, load from
       that path with ``local_files_only=True``.
    3. Otherwise load the repo ID with ``local_files_only=True`` (i.e. from
       HuggingFace's own cache layout).
    4. Only if step 3 raises ``OSError``, load the repo ID without the guard,
       which permits a live network fetch.

    Keeping ``local_files_only=True`` on steps 2 and 3 matters on MPS: HF HEAD
    retry loops leave stale thread state that causes SIGBUS on the first
    ``encode()`` call.

    ``trust_remote_code`` is enabled only for repo IDs containing ``nomic-ai/``.

    :param model_name: HuggingFace model ID or KNOWN_MODELS alias.
    :return: Loaded ``SentenceTransformer`` instance.
    """
    # Imported inside the function so the dependency is only loaded when a
    # model is actually requested.
    from sentence_transformers import SentenceTransformer  # pylint: disable=import-outside-toplevel

    repo_id = KNOWN_MODELS.get(model_name, model_name)
    remote_code_ok = "nomic-ai/" in repo_id
    cached_path = resolve_model_path(repo_id)

    if not cached_path.exists():
        # No project-managed copy: try HF's cache without touching the network.
        try:
            return SentenceTransformer(
                repo_id,
                local_files_only=True,
                trust_remote_code=remote_code_ok,
            )
        except OSError:
            # Absent from the HF cache as well — allow the network fetch.
            return SentenceTransformer(repo_id, trust_remote_code=remote_code_ok)

    return SentenceTransformer(
        str(cached_path),
        local_files_only=True,
        trust_remote_code=remote_code_ok,
    )
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # Concrete embedder
120
+ # ---------------------------------------------------------------------------
121
+
122
+
123
class SentenceTransformerEmbedder(Embedder):
    """Concrete :class:`Embedder` backed by ``sentence-transformers``.

    Model loading is delegated to :func:`load_sentence_transformer`, so the
    ``local_files_only`` guard and KNOWN_MODELS alias resolution always apply.

    :param model_name: HuggingFace model ID or KNOWN_MODELS alias.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL) -> None:
        # Lower transformers' log verbosity to ERROR when transformers is
        # importable. This is process-wide and is not restored afterwards.
        try:
            from transformers import logging as hf_logging  # pylint: disable=import-outside-toplevel

            hf_logging.set_verbosity_error()  # type: ignore[no-untyped-call]
        except ImportError:
            pass

        # Disable tqdm bars only for the duration of the load, then restore
        # whatever TQDM_DISABLE value (or absence of one) the caller had.
        prev_tqdm = os.environ.get("TQDM_DISABLE")
        os.environ["TQDM_DISABLE"] = "1"
        try:
            self.model = load_sentence_transformer(model_name)
        finally:
            if prev_tqdm is None:
                os.environ.pop("TQDM_DISABLE", None)
            else:
                os.environ["TQDM_DISABLE"] = prev_tqdm

        # Alias-resolved model ID, kept as metadata and shown in __repr__.
        self.model_name: str = KNOWN_MODELS.get(model_name, model_name)
        # ST >=5.4 renamed get_sentence_embedding_dimension to
        # get_embedding_dimension; <=5.3 only has the old name, so try both.
        dim_fn = getattr(self.model, "get_embedding_dimension", None) or getattr(
            self.model, "get_sentence_embedding_dimension", None
        )
        # NOTE(review): 384 is used whenever the model reports no dimension;
        # presumably the width of DEFAULT_MODEL — confirm for other models.
        self.dim: int = (dim_fn() if dim_fn is not None else None) or 384

    def embed_texts(self, texts: list[str], encode_batch_size: int = 512) -> list[list[float]]:
        """Embed a list of strings into normalised float32 vectors.

        :param texts: Input strings.
        :param encode_batch_size: Passed to ``model.encode()`` — tune down if OOM on MPS.
        :return: One vector (list of floats) per input string.
        """
        import numpy as np  # pylint: disable=import-outside-toplevel

        vecs = self.model.encode(
            texts,
            batch_size=encode_batch_size,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        # Cast the whole (n, dim) batch once instead of row by row; the
        # resulting nested lists are identical.
        return np.asarray(vecs, dtype="float32").tolist()

    def embed_query(self, query: str) -> list[float]:
        """Embed a single query string into a float32 vector.

        Routed through :meth:`embed_texts` so queries use exactly the same
        ``encode()`` options as batches: normalisation on, progress bar off.

        :param query: Query string.
        :return: The query's embedding vector.
        """
        return self.embed_texts([query])[0]

    def __repr__(self) -> str:
        return f"SentenceTransformerEmbedder(model={self.model_name!r}, dim={self.dim})"
182
+
183
+
184
+ # ---------------------------------------------------------------------------
185
+ # Factory functions
186
+ # ---------------------------------------------------------------------------
187
+
188
+
189
def get_embedder(model_name: str = DEFAULT_MODEL) -> SentenceTransformerEmbedder:
    """Construct a :class:`SentenceTransformerEmbedder` for *model_name*.

    Convenience factory; equivalent to instantiating the class directly,
    which loads the model during construction.

    :param model_name: HuggingFace model ID or KNOWN_MODELS alias.
    :return: A newly constructed embedder with its model loaded.
    """
    embedder = SentenceTransformerEmbedder(model_name)
    return embedder
196
+
197
+
198
class _WrappedEmbedder(Embedder):
    """:class:`Embedder` adapter around a caller-supplied ``SentenceTransformer``.

    The model object is stored by reference, not copied or reloaded.

    :param st_model: Live ``SentenceTransformer`` instance.
    :param model_name: Alias-resolved model name, kept as metadata.
    :param dim: Embedding dimension recorded for *st_model*.
    """

    def __init__(self, st_model: Any, model_name: str, dim: int) -> None:
        # Exposed as ``.model`` for parity with SentenceTransformerEmbedder.
        self.model = st_model
        self.model_name: str = model_name
        self.dim: int = dim

    def embed_texts(self, texts: list[str], encode_batch_size: int = 512) -> list[list[float]]:
        """Embed a list of strings into normalised float32 vectors.

        :param texts: Input strings.
        :param encode_batch_size: Passed to ``model.encode()`` — tune down if OOM on MPS.
        :return: One vector (list of floats) per input string.
        """
        import numpy as np  # pylint: disable=import-outside-toplevel

        vecs = self.model.encode(
            texts,
            batch_size=encode_batch_size,
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        return np.asarray(vecs, dtype="float32").tolist()

    def embed_query(self, query: str) -> list[float]:
        """Embed one query with the same ``encode()`` options as :meth:`embed_texts`.

        :param query: Query string.
        :return: The query's embedding vector.
        """
        return self.embed_texts([query])[0]

    def __repr__(self) -> str:
        return f"_WrappedEmbedder(model={self.model_name!r}, dim={self.dim})"


def wrap_embedder(st_model: Any, model_name: str = DEFAULT_MODEL) -> Embedder:
    """Wrap an already-loaded ``SentenceTransformer`` as an :class:`Embedder`.

    Use this when a live model is already on the GPU (e.g. DiaryTransformer →
    DocKG handoff) to avoid loading a second copy on MPS/CUDA.

    :param st_model: Live ``SentenceTransformer`` instance.
    :param model_name: Model name (or KNOWN_MODELS alias) stored as metadata on the wrapper.
    :return: An :class:`Embedder` that delegates all encoding to *st_model*.
    """
    resolved = KNOWN_MODELS.get(model_name, model_name)
    # ST >=5.4 renamed get_sentence_embedding_dimension to
    # get_embedding_dimension; try the new name first, then the old one.
    dim_fn = getattr(st_model, "get_embedding_dimension", None) or getattr(
        st_model, "get_sentence_embedding_dimension", None
    )
    # NOTE(review): 384 is used whenever the model reports no dimension;
    # presumably the width of DEFAULT_MODEL — confirm for other models.
    dim = (dim_fn() if dim_fn is not None else None) or 384
    return _WrappedEmbedder(st_model, resolved, dim)
@@ -214,10 +214,13 @@ class SnapshotManager:
214
214
  manifest = SnapshotManifest.from_dict(
215
215
  json.loads(self.manifest_path.read_text(encoding="utf-8"))
216
216
  )
217
- # Normalise legacy 'tree_hash' -> 'key'
217
+ # Normalise legacy key fields -> 'key'
218
218
  for entry in manifest.snapshots:
219
- if "key" not in entry and "tree_hash" in entry:
220
- entry["key"] = entry.pop("tree_hash")
219
+ if "key" not in entry:
220
+ if "tree_hash" in entry:
221
+ entry["key"] = entry.pop("tree_hash")
222
+ elif "commit" in entry:
223
+ entry["key"] = entry["commit"]
221
224
  return manifest
222
225
 
223
226
  def _save_manifest(self, manifest: SnapshotManifest) -> None:
@@ -277,19 +280,23 @@ class SnapshotManager:
277
280
  if not current_ts:
278
281
  return None
279
282
  prev_entry = None
280
- for s in sorted(manifest.snapshots, key=lambda x: x["timestamp"], reverse=True):
281
- if s["timestamp"] < current_ts:
283
+ for s in sorted(manifest.snapshots, key=lambda x: x.get("timestamp", ""), reverse=True):
284
+ if s.get("timestamp", "") < current_ts:
282
285
  prev_entry = s
283
286
  break
284
- return self.load_snapshot(prev_entry["key"]) if prev_entry else None
287
+ if not prev_entry:
288
+ return None
289
+ prev_key = prev_entry.get("key", "")
290
+ return self.load_snapshot(prev_key) if prev_key else None
285
291
 
286
292
  def get_baseline(self) -> Snapshot | None:
287
293
  """Get the oldest snapshot (baseline for comparison)."""
288
294
  manifest = self.load_manifest()
289
295
  if not manifest.snapshots:
290
296
  return None
291
- baseline_entry = min(manifest.snapshots, key=lambda x: x["timestamp"])
292
- return self.load_snapshot(baseline_entry["key"])
297
+ baseline_entry = min(manifest.snapshots, key=lambda x: x.get("timestamp", ""))
298
+ baseline_key = baseline_entry.get("key", "")
299
+ return self.load_snapshot(baseline_key) if baseline_key else None
293
300
 
294
301
  def list_snapshots(
295
302
  self,
File without changes
File without changes