keepsake-memory 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
keepsake/attention.py ADDED
@@ -0,0 +1,146 @@
1
+ """注意力追踪 — 统计用户对各个话题的关注频率。
2
+
3
+ 用户反复提起某个话题 -> 该话题关注度上升 -> 相关碎片在搜索中权重更高。
4
+
5
+ 存储: Redis Sorted Set `keepsake:attention`
6
+ - member: 话题词(由 jieba/关键词提取来)
7
+ - score: 关注度累计值(每次提及 +2,情绪烈度加权)
8
+
9
+ 三套时间窗口(同 hot_topics 模式):
10
+ - 全局: keepsake:attention (7天)
11
+ - 日榜: keepsake:attention:daily (2天)
12
+ - 周榜: keepsake:attention:weekly (14天)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ import redis
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
# Redis sorted-set keys, one per attention window.
ATTENTION_SET = "keepsake:attention"  # global window
ATTENTION_DAILY = "keepsake:attention:daily"  # daily ranking
ATTENTION_WEEKLY = "keepsake:attention:weekly"  # weekly ranking

# Expiry (seconds) applied to each attention key.
_ATTENTION_TTL = {
    ATTENTION_SET: 7 * 86400,  # global: 7 days
    ATTENTION_DAILY: 2 * 86400,  # daily ranking: 2 days
    ATTENTION_WEEKLY: 14 * 86400,  # weekly ranking: 14 days
}
35
+
36
+
37
def record_attention(
    client: redis.Redis,
    text: str,
    emotion_intensity: float = 0.0,
    keywords: Optional[List[str]] = None,
    base_increment: float = 2.0,
    emotion_factor: float = 1.5,
) -> None:
    """Record the user's attention to the topics mentioned in a piece of text.

    Each keyword is lower-cased and stripped; results shorter than two
    characters are ignored. Every remaining keyword has its score raised by
    ``base_increment + emotion_intensity * emotion_factor`` in the global,
    daily and weekly attention sets. The TTL of each set is then refreshed.

    Args:
        client: Redis client; a falsy value makes this a no-op.
        text: Source text. It is only checked for emptiness; the topics
            themselves come from ``keywords``.
        emotion_intensity: Extra weight added on top of ``base_increment``.
        keywords: Topic words to credit. Nothing is recorded without them.
        base_increment: Score added per mention before emotion weighting.
        emotion_factor: Multiplier applied to ``emotion_intensity``.

    This is best-effort: any exception is logged at DEBUG and swallowed.

    NOTE(review): because every write refreshes the TTL, a set only expires
    after a full TTL period with no writes. The daily/weekly sets therefore
    keep accumulating while the user stays active; they are not true rolling
    windows.
    """
    # Keywords are expected from the caller (the original comment says store()
    # supplies them); without them there is nothing to record.
    if not client or not text or not keywords:
        return

    attention_sets = (ATTENTION_SET, ATTENTION_DAILY, ATTENTION_WEEKLY)
    try:
        increment = base_increment + emotion_intensity * emotion_factor
        topics = [t for t in (kw.lower().strip() for kw in keywords) if len(t) >= 2]
        if not topics:
            return

        # Queue every command and send them in one round trip instead of six
        # round trips per keyword.
        pipe = client.pipeline(transaction=False)
        for topic in topics:
            for topic_set in attention_sets:
                pipe.zincrby(topic_set, increment, topic)
        # Refreshing each TTL once after all increments has the same effect as
        # refreshing it after every increment.
        for topic_set in attention_sets:
            pipe.expire(topic_set, _ATTENTION_TTL.get(topic_set, 86400))
        pipe.execute()
    except Exception as e:
        logger.debug("attention: record_attention error: %s", e)
66
+
67
+
68
def get_attention_score(client: redis.Redis, keyword: str) -> float:
    """Return the accumulated global attention score of ``keyword``.

    This is the raw score stored in the global attention sorted set, not a
    rank. The keyword is normalised the same way ``record_attention`` stores
    it (lower-cased and stripped).

    Returns 0.0 in three cases: the client or keyword is missing, the keyword
    has no score, or the Redis call fails. A failure is logged at DEBUG, like
    the other helpers in this module.
    """
    if not client or not keyword:
        return 0.0
    try:
        score = client.zscore(ATTENTION_SET, keyword.lower().strip())
        return float(score) if score is not None else 0.0
    except Exception as e:
        logger.debug("attention: get_attention_score error: %s", e)
        return 0.0
77
+
78
+
79
def get_top_attention(client: redis.Redis, limit: int = 10,
                      period: str = "all") -> List[Dict[str, Any]]:
    """Return the highest-scoring attention topics, best first.

    Args:
        client: Redis client; a falsy value yields an empty list.
        limit: Maximum number of topics. Values <= 0 yield an empty list.
        period: "all", "daily" or "weekly"; unknown values fall back to "all".

    Returns:
        A list of ``{"topic": str, "score": float}`` dicts, with each score
        rounded to one decimal. Redis errors are logged at DEBUG and yield [].
    """
    key = {
        "all": ATTENTION_SET,
        "daily": ATTENTION_DAILY,
        "weekly": ATTENTION_WEEKLY,
    }.get(period, ATTENTION_SET)

    if not client:
        return []
    try:
        # ZREVRANGE's end index is inclusive and negative indices count from
        # the end. Without this guard, limit=0 would request (0, -1), i.e. the
        # whole set.
        if limit <= 0:
            return []
        raw = client.zrevrange(key, 0, limit - 1, withscores=True)
        results = []
        for member, score in raw:
            topic = member.decode("utf-8") if isinstance(member, bytes) else member
            if isinstance(score, bytes):
                score = score.decode("utf-8")
            results.append({"topic": topic, "score": round(float(score), 1)})
        return results
    except Exception as e:
        logger.debug("attention: get_top_attention error: %s", e)
        return []
102
+
103
+
104
def match_attention_boost(
    client: redis.Redis,
    content: str,
    top_n: int = 10,
    boost_max: float = 1.5,
) -> float:
    """Return a search-weight multiplier for a fragment based on user attention.

    The ``top_n`` topics are taken from the global attention set. Each topic
    of at least two characters that appears (case-insensitive substring match)
    in ``content`` contributes its score. The result is::

        1.0 + (boost_max - 1.0) * min(matched_score / total_top_score, 1.0)

    so it ranges from 1.0 (no hits) up to ``boost_max``.

    Returns 1.0 in these cases: the client or content is missing,
    ``top_n <= 0``, the set is empty or has a non-positive total score, or
    Redis fails (logged at DEBUG).
    """
    if not client or not content:
        return 1.0

    try:
        # ZREVRANGE end index -1 means "last element", so top_n <= 0 would
        # otherwise fetch the whole set.
        if top_n <= 0:
            return 1.0
        raw = client.zrevrange(ATTENTION_SET, 0, top_n - 1, withscores=True)
        if not raw:
            return 1.0

        content_lower = content.lower()
        matched_score = 0.0
        total_score = 0.0

        for topic_raw, score_raw in raw:
            topic = topic_raw.decode("utf-8") if isinstance(topic_raw, bytes) else topic_raw
            if isinstance(score_raw, bytes):
                score_raw = score_raw.decode("utf-8")
            score = float(score_raw)
            total_score += score
            if len(topic) >= 2 and topic in content_lower:
                matched_score += score

        if total_score <= 0:
            return 1.0

        # Normalise to 1.0..boost_max: the more of the top attention a
        # fragment covers, the closer it gets to boost_max.
        ratio = min(matched_score / total_score, 1.0)
        return 1.0 + (boost_max - 1.0) * ratio

    except Exception as e:
        logger.debug("attention: match_attention_boost error: %s", e)
        return 1.0
@@ -0,0 +1,395 @@
1
+ """Consolidation 引擎 — 将同主题碎片分层提炼为更高层记忆。
2
+
3
+ 工作流程:
4
+ 1. 扫描所有未被消费的碎片(fragment_type != "consumed";已合并的 consolidated 碎片也参与,实现多级提炼)
5
+ 2. 用 jieba 关键词提取做主题分组
6
+ 3. 每组超过 min_group_size 条时,调 LLM 合并为一条高层次摘要
7
+ 4. 存为新碎片(fragment_type="consolidated", level=N+1)
8
+ 5. 软删除原始碎片:标记 fragment_type="consumed",并记录 consumed_by / consumed_at(不硬删)
9
+
10
+ 配置参数:
11
+ - min_group_size: 最少多少条碎片才触发合并(默认 2)
12
+ - max_age_hours: 只合并超过此年龄的碎片(给新碎片时间积累,默认 72h)
13
+ - llm_model: DashScope 模型名(默认 qwen-plus)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import logging
20
+ import os
21
+ import time
22
+ import urllib.error
23
+ import urllib.request
24
+ from collections import defaultdict
25
+ from datetime import datetime, timezone
26
+ from typing import Any, Dict, List, Optional
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
# Base URL of DashScope's OpenAI-compatible API.
DASHSCOPE_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"

# Consolidation defaults.
DEFAULT_MIN_GROUP_SIZE: int = 2  # merge as soon as a topic repeats
DEFAULT_MAX_AGE_HOURS: int = 72  # give new fragments time to accumulate
DEFAULT_LLM_MODEL: str = "qwen-plus"
DEFAULT_BATCH_SIZE: int = 200  # COUNT hint passed to each SCAN call

# Timeout (seconds) passed to urlopen for each LLM request.
LLM_TIMEOUT: int = 30

# First-level consolidation prompt; ``{segments}`` receives the bullet list
# of fragment contents.
CONSOLIDATE_PROMPT = """你是一位知识提炼专家。以下是一组关于同一话题的对话片段。

请将它们合并成一条简洁、信息完整的高层记忆条目,要求:
1. 保留所有关键事实和结论,不丢信息
2. 去掉重复内容
3. 用陈述句表达,像一条知识条目
4. 如果片段之间存在矛盾,指出矛盾但不选边
5. 控制在 200 字以内

对话片段:
{segments}

合并后的知识条目:"""
56
+
57
+
58
def _get_api_key() -> str:
    """Return the API key used for DashScope LLM calls, or "" if none is found.

    Lookup order:
      1. ``$OPENAI_API_KEY``
      2. ``$DASHSCOPE_API_KEY``
      3. ``~/.hermes/config.yaml``: ``providers.dashscope.api_key``, falling
         back to ``model.api_key``.

    Surrounding whitespace and quote characters are stripped from the result,
    whichever source it came from.
    """
    key = os.environ.get("OPENAI_API_KEY", "") or os.environ.get("DASHSCOPE_API_KEY", "")

    if not key:
        config_path = os.path.expanduser("~/.hermes/config.yaml")
        if os.path.isfile(config_path):
            cfg = None
            try:
                import yaml  # optional dependency, only needed for this fallback

                with open(config_path, encoding="utf-8") as f:
                    cfg = yaml.safe_load(f)
            except Exception as e:  # missing yaml, unreadable file, bad YAML
                logger.debug("consolidator: cannot read %s: %s", config_path, e)

            if isinstance(cfg, dict):
                # Sections may be absent, null or non-mapping in hand-edited YAML.
                providers = cfg.get("providers")
                dashscope = providers.get("dashscope") if isinstance(providers, dict) else None
                model_cfg = cfg.get("model")
                key = (
                    (dashscope.get("api_key") if isinstance(dashscope, dict) else None)
                    or (model_cfg.get("api_key") if isinstance(model_cfg, dict) else None)
                    or ""
                )

    return str(key).strip().strip("'\"")
79
+
80
+
81
def _call_llm(messages: List[Dict[str, str]], model: str = DEFAULT_LLM_MODEL,
              max_retries: int = 2) -> Optional[str]:
    """Send a chat-completion request to DashScope and return the reply text.

    Args:
        messages: Chat messages (``role``/``content`` dicts) sent as-is.
        model: DashScope model name.
        max_retries: Extra attempts after the first one, with backoff of
            2s, 4s, ... Retried failures are network errors, timeouts,
            unparseable responses, and HTTP 408/429/5xx. Negative values are
            treated as 0.

    Returns:
        The stripped reply content. This may be "" if the model returned
        empty or null content. Returns None in four cases: no API key is
        configured, the response has no choices, the request is rejected
        with a non-retryable HTTP 4xx, or every attempt fails.
    """
    # http.client.IncompleteRead (truncated body) is an HTTPException, not an
    # OSError, so it must be caught explicitly.
    import http.client

    api_key = _get_api_key()
    if not api_key:
        logger.warning("consolidator: no API key for LLM calls")
        return None

    url = f"{DASHSCOPE_BASE}/chat/completions"
    payload = json.dumps({
        "model": model,
        "messages": messages,
        "max_tokens": 512,
        "temperature": 0.3,
    }).encode("utf-8")
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    attempts = 1 + max(max_retries, 0)
    for attempt in range(attempts):
        if attempt > 0:
            wait = 2.0 * (2 ** (attempt - 1))  # 2s, 4s, ...
            logger.debug("consolidator: retry %d/%d after %.0fs", attempt, max_retries, wait)
            time.sleep(wait)

        req = urllib.request.Request(url, data=payload, headers=headers, method="POST")
        try:
            with urllib.request.urlopen(req, timeout=LLM_TIMEOUT) as resp:
                data = json.loads(resp.read())
        except urllib.error.HTTPError as e:
            # Client errors other than 408 (timeout) and 429 (rate limit),
            # e.g. a bad key or a malformed request, will fail again on retry.
            if 400 <= e.code < 500 and e.code not in (408, 429):
                logger.warning("consolidator: LLM call rejected with HTTP %d: %s",
                               e.code, e.reason)
                return None
            error = e
        except (urllib.error.URLError, http.client.HTTPException, ValueError, OSError) as e:
            # ValueError covers json.JSONDecodeError.
            error = e
        else:
            choices = data.get("choices") if isinstance(data, dict) else None
            if not isinstance(choices, list) or not choices:
                # Got a response but no choices; a retry is not expected to help.
                logger.warning("consolidator: LLM returned no choices (attempt %d)", attempt + 1)
                return None
            first = choices[0] if isinstance(choices[0], dict) else {}
            message = first.get("message")
            content = message.get("content") if isinstance(message, dict) else None
            # content can be null or missing; never call .strip() on a non-str.
            return content.strip() if isinstance(content, str) else ""

        if attempt < attempts - 1:
            logger.debug("consolidator: LLM attempt %d failed: %s", attempt + 1, error)
        else:
            logger.warning("consolidator: LLM call failed after %d attempts: %s",
                           attempt + 1, error)
    return None
127
+
128
+
129
class Consolidator:
    """Hierarchical consolidation engine for memory fragments.

    One ``consolidate()`` pass works in three steps. It scans
    ``memory:frag:*`` hashes that are old enough and not yet consumed. It
    groups them by keyword overlap. It then asks the LLM to merge each
    large-enough group into one higher-level fragment, and soft-deletes the
    source fragments by marking them ``fragment_type="consumed"``.
    """

    def __init__(
        self,
        storage: Any,  # RedisStorage instance (typed Any to avoid a circular import)
        min_group_size: int = DEFAULT_MIN_GROUP_SIZE,
        max_age_hours: int = DEFAULT_MAX_AGE_HOURS,
        llm_model: str = DEFAULT_LLM_MODEL,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ):
        """Configure the engine.

        Args:
            storage: Object whose ``_get_client()`` returns a Redis client, or
                a falsy value when Redis is unavailable.
            min_group_size: Minimum group size, and minimum number of non-empty
                fragment contents, required before a group is merged.
            max_age_hours: Fragments created more recently than this are left
                for a later pass.
            llm_model: Model name passed to the LLM call.
            batch_size: COUNT hint for each SCAN call. Every matching key is
                still visited in a pass; this does not cap the fragment count.
        """
        self._storage = storage
        self._min_group_size = min_group_size
        self._max_age_hours = max_age_hours
        self._llm_model = llm_model
        self._batch_size = batch_size

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_str(value: Any) -> Any:
        """Decode Redis bytes as UTF-8; return any other value unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    @staticmethod
    def _parse_level(fragment: Dict[str, Any]) -> int:
        """Return the fragment's integer ``level``; missing or malformed values count as 1."""
        try:
            return int(fragment.get("level", "1"))
        except (ValueError, TypeError):
            return 1

    @staticmethod
    def _eligible_doc(key_b: Any, data: Dict[Any, Any], cutoff: float) -> Optional[Dict[str, Any]]:
        """Decode one fragment hash and return it with ``_key`` set, or None if ineligible."""
        if not data:
            return None
        doc = {Consolidator._to_str(k): Consolidator._to_str(v) for k, v in data.items()}

        # Skip fragments already swallowed by a higher-level merge.
        if doc.get("fragment_type", "") == "consumed":
            return None

        created_str = doc.get("created", "")
        if created_str:
            try:
                # NOTE(review): a naive timestamp is interpreted in local time
                # by .timestamp(); the consolidator itself writes aware UTC.
                if datetime.fromisoformat(created_str).timestamp() > cutoff:
                    return None  # too new; wait for a later pass
            except (ValueError, TypeError):
                pass  # unparseable timestamp: treat as old enough

        doc["_key"] = Consolidator._to_str(key_b)
        return doc

    # --------------------------------------------------------------- main flow

    def consolidate(self) -> Dict[str, Any]:
        """Run one consolidation pass and return statistics.

        Returns:
            ``{"status": "error", "reason": ...}`` when no Redis client is
            available. Otherwise a dict of counters: ``scanned`` (eligible
            fragments), ``groups_found``, and ``merged`` / ``skipped`` /
            ``errors``. The last three count fragments, not groups.
        """
        client = self._storage._get_client()
        if not client:
            return {"status": "error", "reason": "Redis not available"}

        stats = {"scanned": 0, "groups_found": 0, "merged": 0, "skipped": 0, "errors": 0}

        # 1. Collect fragments eligible for merging.
        fragments = self._scan_unconsolidated(client)
        stats["scanned"] = len(fragments)
        if not fragments:
            return stats

        # 2. Cluster them by topic.
        groups = self._cluster_by_topic(fragments)
        stats["groups_found"] = len(groups)

        # 3. Merge every group that is large enough.
        for group in groups:
            if len(group) < self._min_group_size:
                stats["skipped"] += len(group)
                continue
            try:
                merged = self._merge_group(client, group)
            except Exception as e:
                # One failing group (Redis error, LLM/sentiment failure) must
                # not abort the rest of the pass.
                logger.warning("consolidator: merge failed for group of %d: %s", len(group), e)
                merged = False
            stats["merged" if merged else "errors"] += len(group)

        return stats

    def _scan_unconsolidated(self, client) -> List[Dict[str, Any]]:
        """Collect fragments eligible for consolidation.

        A ``memory:frag:*`` hash is eligible when both conditions hold:
          - its ``fragment_type`` is not "consumed". Already-consolidated
            fragments do take part, which is how multi-level refinement
            happens.
          - its ``created`` timestamp is at least ``max_age_hours`` old.
            Missing or unparseable timestamps count as old enough.

        Each returned dict holds the decoded hash fields plus ``_key``.
        Returns [] (after logging a warning) if any Redis call fails.
        """
        try:
            cutoff = datetime.now(timezone.utc).timestamp() - self._max_age_hours * 3600
            cursor = 0
            fragments: List[Dict[str, Any]] = []

            while True:
                cursor, keys = client.scan(
                    cursor=cursor,
                    match="memory:frag:*",
                    count=self._batch_size,
                )
                if keys:
                    # Fetch every hash of this SCAN page (HGETALL) in one round trip.
                    pipe = client.pipeline()
                    for key_b in keys:
                        pipe.hgetall(key_b)
                    for key_b, data in zip(keys, pipe.execute()):
                        doc = self._eligible_doc(key_b, data, cutoff)
                        if doc is not None:
                            fragments.append(doc)
                if cursor == 0:
                    break

            return fragments

        except Exception as e:
            logger.warning("consolidator: scan error: %s", e)
            return []

    def _cluster_by_topic(self, fragments: List[Dict]) -> List[List[Dict]]:
        """Greedily group fragments whose keyword sets overlap.

        Up to 5 keywords are extracted from each fragment's ``content`` with
        ``splitter.extract_keywords``; fragments without content are dropped.
        Each not-yet-assigned fragment seeds a group. The group then takes in
        every later unassigned fragment that shares at least 2 keywords with
        the seed (not with the other members). The result depends on input
        order and is not an optimal clustering.
        """
        from .splitter import extract_keywords

        frag_data = []
        for f in fragments:
            content = f.get("content", "")
            if content:
                frag_data.append((f, set(extract_keywords(content, max_keywords=5))))

        groups: List[List[Dict]] = []
        assigned = set()
        for i, (seed, seed_kws) in enumerate(frag_data):
            if i in assigned:
                continue
            assigned.add(i)
            group = [seed]
            # Every earlier index is already assigned, so only later ones can join.
            for j in range(i + 1, len(frag_data)):
                if j in assigned:
                    continue
                other, other_kws = frag_data[j]
                if len(seed_kws & other_kws) >= 2:
                    group.append(other)
                    assigned.add(j)
            groups.append(group)

        return groups

    def _merge_group(self, client, group: List[Dict]) -> bool:
        """Merge one group of fragments into a consolidated fragment via the LLM.

        Steps:
          1. Build a prompt from up to 300 characters of each fragment's
             content. A "refine further" prompt is used when the group already
             contains consolidated or level > 1 fragments.
          2. Call the LLM.
          3. Store the result as ``memory:frag:<sha256[:12]>`` with
             ``level = max(level) + 1``, unless that key already exists.
          4. Mark every source fragment as consumed.

        Returns:
            False if the group has fewer than ``min_group_size`` non-empty
            contents or the LLM returns nothing; True otherwise. Redis errors
            while storing the new fragment propagate to the caller.
        """
        # New level = highest level in the group + 1 (raw fragments are level 1).
        levels = [self._parse_level(f) for f in group]
        new_level = max([1, *levels]) + 1

        # Does the group already contain refined material?
        has_consolidated = any(
            f.get("fragment_type") == "consolidated" or level > 1
            for f, level in zip(group, levels)
        )

        # Prompt segments and the merged tag set (session tags are dropped).
        segments = []
        tags_set = set()
        for f in group:
            content = f.get("content", "")
            if content:
                segments.append(f"• {content[:300]}")
            tags = f.get("tags", "")
            if tags:
                for t in tags.split(","):
                    t = t.strip()
                    if t and not t.startswith("session:"):
                        tags_set.add(t)

        if len(segments) < self._min_group_size:
            return False

        if has_consolidated:
            prompt = (
                "以下是一组已经提炼过的记忆条目和相关的原始对话片段。"
                "请将它们进一步提炼合并成一条更精炼的高层知识条目。\n\n"
                + "\n".join(segments)
                + "\n\n提炼后的高层知识条目:"
            )
        else:
            prompt = CONSOLIDATE_PROMPT.format(segments="\n".join(segments))
        result = _call_llm([
            {"role": "system", "content": "你是一位知识提炼专家,擅长从对话中提取核心信息。"},
            {"role": "user", "content": prompt},
        ], model=self._llm_model)

        if not result:
            return False

        from .splitter import analyze_sentiment
        sent_score, sent_label = analyze_sentiment(result)
        now_str = datetime.now(timezone.utc).isoformat()

        mapping = {
            "content": result,
            "tags": ",".join(sorted(tags_set)) if tags_set else "",
            "category": "consolidated",
            "source": "consolidator",
            "created": now_str,
            "fragment_type": "consolidated",
            "level": str(new_level),  # raw=1, first merge=2, second merge=3, ...
            "sentiment_score": str(sent_score),
            "sentiment_label": sent_label,
            "feedback_score": "0",
        }

        import hashlib
        content_hash = hashlib.sha256(result.encode()).hexdigest()[:12]
        consolidated_key = f"memory:frag:{content_hash}"

        # Content-addressed key: an identical consolidated result is stored once.
        if client.exists(consolidated_key):
            logger.debug("consolidator: duplicate consolidated result, skipping")
        else:
            client.hset(consolidated_key, mapping=mapping)

        # Soft-delete the sources in a single HSET each (no hard delete).
        consumed_fields = {
            "consumed_by": consolidated_key,
            "consumed_at": now_str,
            "fragment_type": "consumed",
        }
        consumed_count = 0
        for f in group:
            key = f.get("_key")
            # Never consume the merge target itself. This can happen when the
            # LLM reproduces an existing consolidated fragment verbatim.
            if not key or key == consolidated_key:
                continue
            try:
                client.hset(key, mapping=consumed_fields)
                consumed_count += 1
            except Exception as e:
                logger.debug("consolidator: failed to mark %s consumed: %s", key, e)

        logger.info(
            "consolidator: merged %d fragments → '%s...' (marked %d as consumed)",
            len(group), result[:60], consumed_count,
        )
        return True
keepsake/embedder.py ADDED
@@ -0,0 +1,155 @@
1
+ """Embedding 客户端 — 支持 OpenAI / DashScope / 自定义兼容端点。"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from abc import ABC, abstractmethod
9
+ from typing import Optional
10
+ from urllib.request import Request, urlopen
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # 模型 → 维度映射
16
+ # ---------------------------------------------------------------------------
17
+
18
# Default output vector size of each known embedding model, keyed by the
# lower-cased model name.
_MODEL_DIMENSIONS: dict[str, int] = {
    # OpenAI: text-embedding-3 family and ada-002
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
    # DashScope
    "text-embedding-v2": 1536,
    "text-embedding-v3": 1024,
}

_DEFAULT_DIM = 1536


def resolve_dimension(model: str) -> int:
    """Look up the vector size for ``model``, ignoring case and surrounding spaces.

    Models not listed in ``_MODEL_DIMENSIONS`` fall back to ``_DEFAULT_DIM`` (1536).
    """
    normalized = model.strip().lower()
    if normalized in _MODEL_DIMENSIONS:
        return _MODEL_DIMENSIONS[normalized]
    return _DEFAULT_DIM
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # 抽象基类
39
+ # ---------------------------------------------------------------------------
40
+
41
+
42
class Embedder(ABC):
    """Abstract interface for text-embedding backends.

    Subclasses must implement both ``dimension`` and ``get_embedding``.
    """

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of the vectors produced by the current model."""

    @abstractmethod
    def get_embedding(self, text: str) -> Optional[list[float]]:
        """Embed ``text`` and return the float vector (None when unavailable)."""
53
+
54
+
55
+ # ---------------------------------------------------------------------------
56
+ # OpenAI 兼容
57
+ # ---------------------------------------------------------------------------
58
+
59
# API *base* URL: OpenAIEmbedder appends "/embeddings" itself. The previous
# value already ended in "/embeddings", which produced ".../embeddings/embeddings".
_DEFAULT_OPENAI_URL = "https://api.openai.com/v1"
_DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
61
+
62
+
63
class OpenAIEmbedder(Embedder):
    """Client for OpenAI-compatible embedding APIs.

    Sends ``POST {base_url}/embeddings`` with an OpenAI-style request body.
    Other services that expose the same endpoint (such as DashScope's
    compatible mode) work too.
    """

    def __init__(
        self,
        api_key: str = "",
        base_url: str = _DEFAULT_OPENAI_URL,
        model: str = _DEFAULT_OPENAI_MODEL,
    ):
        """Configure the client.

        Args:
            api_key: Bearer token; falls back to ``$OPENAI_API_KEY``.
            base_url: API base URL such as "https://api.openai.com/v1". A full
                endpoint ending in "/embeddings" is also accepted; the suffix
                is removed so the request path is not doubled.
            model: Embedding model name; also used to look up the expected
                dimension.
        """
        self._api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
        base = base_url.rstrip("/")
        # get_embedding appends "/embeddings"; normalise a full endpoint URL
        # (including the historical default) back to its base.
        suffix = "/embeddings"
        if base.endswith(suffix):
            base = base[: -len(suffix)]
        self._base_url = base
        self._model = model
        self._dim = resolve_dimension(model)

    @property
    def dimension(self) -> int:
        """Expected vector length; updated if the service returns a different size."""
        return self._dim

    def get_embedding(self, text: str) -> Optional[list[float]]:
        """Return the embedding vector of ``text``, or None on failure.

        None is returned when no API key is configured (logged as a warning)
        or when the request or response parsing fails (logged at DEBUG). If
        the returned vector's length differs from the expected dimension,
        ``dimension`` is updated to the observed length.
        """
        if not self._api_key:
            logger.warning("embedder: no API key configured")
            return None

        payload = json.dumps({
            "model": self._model,
            "input": text,
        }).encode("utf-8")

        req = Request(
            f"{self._base_url}/embeddings",
            data=payload,
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
            method="POST",
        )

        try:
            with urlopen(req, timeout=30) as resp:
                data = json.loads(resp.read())
            emb = data["data"][0]["embedding"]
            # Trust the service if its vector size differs from our table.
            if len(emb) != self._dim:
                logger.info(
                    "embedder: model %s returned %d dims (expected %d), updating",
                    self._model, len(emb), self._dim,
                )
                self._dim = len(emb)
            return emb
        except Exception as e:
            logger.debug("embedder: request failed: %s", e)
            return None
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # 工厂
123
+ # ---------------------------------------------------------------------------
124
+
125
# Provider name -> Embedder class. DashScope's compatible mode serves the same
# /embeddings API, so both names map to OpenAIEmbedder.
_EMBEDDER_PROVIDERS: dict[str, type[Embedder]] = dict.fromkeys(
    ("openai", "dashscope"), OpenAIEmbedder
)
129
+
130
+
131
def create_embedder(
    provider: str = "",
    api_key: str = "",
    base_url: str = "",
    model: str = "",
) -> Embedder:
    """Create an Embedder from explicit arguments or environment variables.

    Args:
        provider: "openai", "dashscope" or another name (case-insensitive).
            Defaults to ``$FRAGMENTED_EMBEDDER``, then "openai". Unregistered
            names use OpenAIEmbedder.
        api_key: API key. Defaults to ``$OPENAI_API_KEY``; for "dashscope",
            ``$DASHSCOPE_API_KEY`` is tried next.
        base_url: API endpoint. Defaults to ``$FRAGMENTED_EMBEDDER_URL``, then
            the provider's default URL.
        model: Model name. Defaults to ``$FRAGMENTED_EMBEDDER_MODEL``, then
            the provider's default model.
    """
    provider = (provider or os.environ.get("FRAGMENTED_EMBEDDER", "openai")).strip().lower()

    # Choose provider defaults *before* filling gaps. Previously the OpenAI
    # defaults were applied first, so the DashScope defaults never took effect.
    if provider == "dashscope":
        default_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        default_model = "text-embedding-v2"
    else:
        default_url = _DEFAULT_OPENAI_URL
        default_model = _DEFAULT_OPENAI_MODEL

    api_key = api_key or os.environ.get("OPENAI_API_KEY", "")
    if not api_key and provider == "dashscope":
        api_key = os.environ.get("DASHSCOPE_API_KEY", "")
    # `or` chains so an env var set to "" still falls back to the default.
    base_url = base_url or os.environ.get("FRAGMENTED_EMBEDDER_URL", "") or default_url
    model = model or os.environ.get("FRAGMENTED_EMBEDDER_MODEL", "") or default_model

    cls = _EMBEDDER_PROVIDERS.get(provider, OpenAIEmbedder)
    return cls(api_key=api_key, base_url=base_url, model=model)