screenforge 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. cli/__init__.py +0 -0
  2. cli/_version.py +1 -0
  3. cli/dispatch.py +266 -0
  4. cli/doctor.py +487 -0
  5. cli/modes/__init__.py +0 -0
  6. cli/modes/action.py +262 -0
  7. cli/modes/default.py +248 -0
  8. cli/modes/demo.py +162 -0
  9. cli/modes/dry_run.py +237 -0
  10. cli/modes/init.py +133 -0
  11. cli/modes/plan.py +148 -0
  12. cli/modes/workflow.py +354 -0
  13. cli/parser.py +305 -0
  14. cli/reporter.py +207 -0
  15. cli/session.py +146 -0
  16. cli/shared.py +427 -0
  17. cli/shorthand.py +90 -0
  18. cli/tool_protocol_handlers.py +446 -0
  19. common/__init__.py +0 -0
  20. common/adapters/__init__.py +21 -0
  21. common/adapters/android_adapter.py +273 -0
  22. common/adapters/base_adapter.py +24 -0
  23. common/adapters/ios_adapter.py +278 -0
  24. common/adapters/web_adapter.py +271 -0
  25. common/ai.py +277 -0
  26. common/ai_autonomous.py +273 -0
  27. common/ai_heal.py +222 -0
  28. common/cache/__init__.py +15 -0
  29. common/cache/cache_hash.py +57 -0
  30. common/cache/cache_manager.py +300 -0
  31. common/cache/cache_stats.py +133 -0
  32. common/cache/cache_storage.py +79 -0
  33. common/cache/embedding_loader.py +150 -0
  34. common/capabilities.py +121 -0
  35. common/case_memory.py +327 -0
  36. common/error_codes.py +61 -0
  37. common/exceptions.py +18 -0
  38. common/executor.py +1504 -0
  39. common/failure_diagnosis.py +138 -0
  40. common/history_manager.py +75 -0
  41. common/logs.py +168 -0
  42. common/mcp_server.py +467 -0
  43. common/preflight.py +496 -0
  44. common/progress.py +37 -0
  45. common/run_reporter.py +415 -0
  46. common/run_resume.py +149 -0
  47. common/runtime_modes.py +35 -0
  48. common/tool_protocol.py +196 -0
  49. common/visual_fallback.py +71 -0
  50. common/workflow_schema.py +150 -0
  51. config/__init__.py +0 -0
  52. config/config.py +167 -0
  53. config/env_loader.py +76 -0
  54. screenforge-0.4.0.dist-info/METADATA +43 -0
  55. screenforge-0.4.0.dist-info/RECORD +64 -0
  56. screenforge-0.4.0.dist-info/WHEEL +5 -0
  57. screenforge-0.4.0.dist-info/entry_points.txt +2 -0
  58. screenforge-0.4.0.dist-info/licenses/LICENSE +21 -0
  59. screenforge-0.4.0.dist-info/top_level.txt +4 -0
  60. utils/__init__.py +0 -0
  61. utils/screenshot_annotator.py +60 -0
  62. utils/utils_ios.py +195 -0
  63. utils/utils_web.py +304 -0
  64. utils/utils_xml.py +218 -0
@@ -0,0 +1,300 @@
1
+ import os
2
+
3
# Process-wide defaults for the Hugging Face / transformers libraries. They sit
# ahead of the remaining imports, presumably so those libraries see them at
# import time (TODO confirm). setdefault() keeps any value the user already
# exported (e.g. a different HF_ENDPOINT) instead of silently overwriting it.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
6
+
7
+ from datetime import datetime, timezone
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ import numpy as np
11
+ from numpy.linalg import norm
12
+
13
+ from common.logs import log
14
+ from config.config import CACHE_EXACT_MATCH_THRESHOLD, CACHE_SIMILARITY_THRESHOLD
15
+
16
+ from .cache_hash import compute_instruction_hash, compute_ui_hash
17
+ from .cache_stats import CacheStats
18
+ from .cache_storage import cleanup_expired_entries, load_cache, save_cache
19
+ from .embedding_loader import EmbeddingModelLoader
20
+
21
+
22
class CacheManager:
    """Cache of AI decisions with an exact-key tier and a semantic tier.

    Lookups first try an exact key built from hashes; on a miss they fall back
    to cosine similarity between the instruction embedding and the embeddings
    stored with previous entries. Two entry types are used:

    * ``"L1-Action"``  - keyed by platform, instruction hash and UI hash.
    * ``"L2-SimpleQA"`` - keyed by platform and instruction hash only.

    Every public method is best-effort: failures are logged and reported as a
    miss (``None``) or ``False`` instead of being raised.
    """

    # Similarity floor used by get_chat_simple() for "L2-SimpleQA" lookups.
    CHAT_SIMILARITY_THRESHOLD = 0.88

    def __init__(
        self,
        cache_dir: str = ".cache",
        enabled: bool = False,
        ttl_days: int = 365,
        max_size_mb: int = 100,
    ):
        """Create a manager bound to *cache_dir*.

        Args:
            cache_dir: Directory holding the cache and statistics files.
            enabled: When False, lookups return None and writes return False.
            ttl_days: Lifetime stamped on new entries and used as the default
                when expiring old ones.
            max_size_mb: Stored for inspection only; nothing in this class
                enforces a size limit.
        """
        self._cache_dir = cache_dir
        self._enabled = enabled
        self._ttl_seconds = ttl_days * 24 * 60 * 60
        self._max_size_mb = max_size_mb
        self._stats = CacheStats(cache_dir)
        self._model_loader = EmbeddingModelLoader()

    @property
    def enabled(self) -> bool:
        """Whether lookups and writes are active."""
        return self._enabled

    @enabled.setter
    def enabled(self, value: bool) -> None:
        self._enabled = value

    def _get_model(self):
        """Return the embedding model from the loader (None when unavailable)."""
        return self._model_loader.load()

    def _get_embedding(self, text: str) -> Optional[list]:
        """Embed *text* as a plain list, or return None without a model.

        Returning None lets callers fall back to exact-key-only behaviour
        instead of crashing when the ML stack is not installed.
        """
        model = self._get_model()
        if model is None:
            return None
        return model.encode(text).tolist()

    def _cosine_similarity(self, vec1: list, vec2: list) -> float:
        """Return the cosine similarity of two vectors.

        Returns 0.0 when either vector has zero norm or when the two vectors
        do not have the same shape (e.g. entries written by a different
        embedding model), instead of raising.
        """
        v1, v2 = np.asarray(vec1, dtype=float), np.asarray(vec2, dtype=float)
        if v1.shape != v2.shape:
            return 0.0
        norm1, norm2 = norm(v1), norm(v2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(np.dot(v1, v2) / (norm1 * norm2))

    def _cosine_similarity_batch(
        self, query_vec: list, candidate_vecs: List[list]
    ) -> List[tuple]:
        """Score every candidate against *query_vec*.

        Returns:
            ``[(index, score), ...]`` with one pair per candidate, in input
            order. An empty list is returned when there are no candidates or
            the query has zero norm. Candidates with zero norm, or whose
            length differs from the query's, score 0.0 rather than making the
            whole batch fail.
        """
        if not candidate_vecs:
            return []
        query = np.asarray(query_vec, dtype=float)
        if query.ndim != 1:
            return []
        query_norm = norm(query)
        if query_norm == 0:
            return []

        dim = query.shape[0]
        scores = [0.0] * len(candidate_vecs)
        # Only same-length vectors can be stacked into one matrix.
        usable = [i for i, vec in enumerate(candidate_vecs) if len(vec) == dim]
        if usable:
            matrix = np.asarray([candidate_vecs[i] for i in usable], dtype=float)
            norms = norm(matrix, axis=1)
            nonzero = norms > 0
            sims = np.zeros(len(usable))
            if np.any(nonzero):
                sims[nonzero] = np.dot(matrix[nonzero], query) / (
                    norms[nonzero] * query_norm
                )
            for position, score in zip(usable, sims.tolist()):
                scores[position] = score
        return list(enumerate(scores))

    @staticmethod
    def _touch(entry: Dict[str, Any]) -> None:
        """Refresh ``last_accessed`` and bump ``access_count`` on *entry*."""
        metadata = entry.setdefault("metadata", {})
        metadata["last_accessed"] = datetime.now(timezone.utc).isoformat()
        metadata["access_count"] = metadata.get("access_count", 0) + 1

    def _record(self, hit: bool) -> None:
        """Count a hit or miss; a statistics failure never affects the lookup."""
        try:
            if hit:
                self._stats.increment_hit()
            else:
                self._stats.increment_miss()
        except Exception as e:
            log.warning(f"[Cache] Could not persist cache statistics: {e}")

    def _register_hit(self, entry: Dict[str, Any], cache_data: Dict[str, Any]) -> Any:
        """Do the bookkeeping for a hit and return the cached decision.

        Persisting the access metadata is best-effort: the decision was
        already found, so a write failure must not turn the hit into a miss.
        """
        self._touch(entry)
        try:
            save_cache(self._cache_dir, cache_data)
        except Exception as e:
            log.warning(f"[Cache] Could not persist access metadata: {e}")
        self._record(hit=True)
        return entry.get("decision")

    def _hybrid_search(
        self,
        instruction: str,
        target_ui_hash: Optional[str],
        cache_type: str,
        platform: str,
        threshold: float,
        exact_key: str,
    ) -> Optional[Dict]:
        """Look up a decision by exact key, then by embedding similarity.

        Args:
            instruction: Natural-language instruction being resolved.
            target_ui_hash: When truthy, semantic candidates must carry this
                exact ``ui_hash``.
            cache_type: Entry type to search (``"L1-Action"``/``"L2-SimpleQA"``).
            platform: Only entries recorded for this platform are considered.
            threshold: Minimum cosine similarity accepted as a semantic hit.
            exact_key: Key tried first for an exact match.

        Returns:
            The cached decision, or None on a miss, when disabled, or on error.
        """
        if not self._enabled:
            return None

        try:
            cache_data = cleanup_expired_entries(
                load_cache(self._cache_dir), self._ttl_seconds
            )
            entries = cache_data.get("entries", {})
            if not entries:
                self._record(hit=False)
                return None

            # Tier 1: exact key (still has to belong to the same platform).
            exact_entry = entries.get(exact_key)
            if exact_entry is not None and exact_entry.get("platform") == platform:
                saved_time = exact_entry.get("metadata", {}).get("llm_latency", 0.0)
                log.info(f"[Exact Cache Hit] {cache_type} matched ({platform})")
                if saved_time > 0:
                    log.info(f"[Cache] Saved {saved_time:.2f}s of LLM inference time")
                return self._register_hit(exact_entry, cache_data)

            # Tier 2: semantic similarity, which needs an embedding.
            current_vector = self._get_embedding(instruction)
            if current_vector is None:
                # ML stack unavailable: the exact key already missed and no
                # similarity can be computed, so this is a miss.
                self._record(hit=False)
                return None

            candidate_entries = []
            candidate_vectors = []
            for entry in entries.values():
                if entry.get("type") != cache_type:
                    continue
                if entry.get("platform") != platform:
                    continue
                if target_ui_hash and entry.get("ui_hash") != target_ui_hash:
                    continue
                past_vector = entry.get("instruction_vector")
                if not past_vector:
                    continue
                candidate_entries.append(entry)
                candidate_vectors.append(past_vector)

            if not candidate_vectors:
                log.debug("[Cache Miss] No candidate vectors to compare")
                self._record(hit=False)
                return None

            scores = self._cosine_similarity_batch(current_vector, candidate_vectors)
            if not scores:
                log.debug("[Cache Miss] No similarity scores computed")
                self._record(hit=False)
                return None

            best_idx, best_score = max(scores, key=lambda s: s[1])
            if best_score >= threshold:
                matched_entry = candidate_entries[best_idx]
                past_inst = matched_entry.get("instruction")
                saved_time = matched_entry.get("metadata", {}).get("llm_latency", 0.0)
                log.info(f"[Semantic Cache Hit] {cache_type} matched, similarity: {best_score:.2%} ({platform})")
                log.info(f"[Cache] Query: '{instruction}' | Matched: '{past_inst}'")
                if saved_time > 0:
                    log.info(f"[Cache] Saved {saved_time:.2f}s of LLM inference time")
                return self._register_hit(matched_entry, cache_data)

            log.debug(f"[Cache Miss] Best similarity {best_score:.2%} below threshold {threshold:.2%}")
            self._record(hit=False)
            return None
        except Exception as e:
            log.error(f"[Cache Error] Retrieval failed: {e}")
            return None

    def _set_hybrid(
        self,
        instruction: str,
        decision: Dict,
        ui_hash: Optional[str],
        cache_type: str,
        platform: str,
        exact_key: str,
        llm_latency: float = 0.0,
    ) -> bool:
        """Store *decision* under *exact_key*, de-duplicating related entries.

        For ``"L1-Action"``, older entries with the same instruction and
        decision but a different ``ui_hash`` are removed. For
        ``"L2-SimpleQA"``, if an existing entry has the same decision and an
        embedding similarity above ``CACHE_EXACT_MATCH_THRESHOLD``, that entry
        is refreshed and no new entry is written.

        Returns:
            True when the cache was written, False when disabled or on error.
        """
        if not self._enabled:
            return False
        try:
            cache_data = load_cache(self._cache_dir)
            entries = cache_data.setdefault("entries", {})
            current_vector = self._get_embedding(instruction)

            stale_keys = []
            for key, existing in entries.items():
                if existing.get("type") != cache_type:
                    continue
                if existing.get("platform") != platform:
                    continue
                is_same_decision = existing.get("decision") == decision
                is_same_instruction = existing.get("instruction") == instruction
                if cache_type == "L1-Action":
                    if (
                        is_same_instruction
                        and is_same_decision
                        and existing.get("ui_hash") != ui_hash
                    ):
                        stale_keys.append(key)
                elif cache_type == "L2-SimpleQA":
                    past_vector = existing.get("instruction_vector")
                    if current_vector is not None and past_vector and is_same_decision:
                        sim = self._cosine_similarity(current_vector, past_vector)
                        if sim > CACHE_EXACT_MATCH_THRESHOLD:
                            # Near-duplicate question with the same answer:
                            # refresh it instead of storing a second copy.
                            self._touch(existing)
                            save_cache(self._cache_dir, cache_data)
                            return True

            for key in stale_keys:
                del entries[key]

            now_iso = datetime.now(timezone.utc).isoformat()
            entry = {
                "type": cache_type,
                "platform": platform,
                "instruction": instruction,
                "instruction_vector": current_vector,
                "decision": decision,
                "metadata": {
                    "created_at": now_iso,
                    "last_accessed": datetime.now(timezone.utc).isoformat(),
                    "access_count": 1,
                    "ttl_seconds": self._ttl_seconds,
                    "llm_latency": round(llm_latency, 2),
                },
            }
            if ui_hash is not None:
                entry["ui_hash"] = ui_hash
            entries[exact_key] = entry
            save_cache(self._cache_dir, cache_data)
            return True
        except Exception as e:
            log.error(f"[Cache Error] Write failed: {e}")
            return False

    def get(self, instruction: str, ui_json: Dict[str, Any], platform: str) -> Optional[Dict[str, Any]]:
        """Look up an ``"L1-Action"`` decision for *instruction* on this UI."""
        try:
            ui_hash = compute_ui_hash(ui_json)
            inst_hash = compute_instruction_hash(instruction)
            exact_key = f"L1_{platform}_{inst_hash}_{ui_hash}"
            return self._hybrid_search(
                instruction, ui_hash, "L1-Action", platform, CACHE_SIMILARITY_THRESHOLD, exact_key
            )
        except Exception as e:
            log.error(f"[Cache Error] get() failed: {e}")
            return None

    def set(
        self,
        instruction: str,
        ui_json: Dict[str, Any],
        decision: Dict[str, Any],
        platform: str,
        llm_latency: float = 0.0,
    ) -> bool:
        """Store an ``"L1-Action"`` decision for *instruction* on this UI."""
        try:
            ui_hash = compute_ui_hash(ui_json)
            inst_hash = compute_instruction_hash(instruction)
            exact_key = f"L1_{platform}_{inst_hash}_{ui_hash}"
            return self._set_hybrid(
                instruction, decision, ui_hash, "L1-Action", platform, exact_key, llm_latency
            )
        except Exception as e:
            log.error(f"[Cache Error] set() failed: {e}")
            return False

    def get_chat_simple(self, instruction: str, platform: str) -> Optional[Dict[str, Any]]:
        """Look up an ``"L2-SimpleQA"`` decision (no UI hash involved)."""
        try:
            inst_hash = compute_instruction_hash(instruction)
            exact_key = f"L2_{platform}_{inst_hash}"
            return self._hybrid_search(
                instruction,
                None,
                "L2-SimpleQA",
                platform,
                self.CHAT_SIMILARITY_THRESHOLD,
                exact_key,
            )
        except Exception as e:
            log.error(f"[Cache Error] get_chat_simple() failed: {e}")
            return None

    def set_chat_simple(
        self,
        instruction: str,
        decision: Dict[str, Any],
        platform: str,
        llm_latency: float = 0.0,
    ) -> bool:
        """Store an ``"L2-SimpleQA"`` decision (no UI hash involved)."""
        try:
            inst_hash = compute_instruction_hash(instruction)
            exact_key = f"L2_{platform}_{inst_hash}"
            return self._set_hybrid(
                instruction, decision, None, "L2-SimpleQA", platform, exact_key, llm_latency
            )
        except Exception as e:
            log.error(f"[Cache Error] set_chat_simple() failed: {e}")
            return False

    def clear(self) -> bool:
        """Replace the cache file with an empty cache; True on success."""
        try:
            save_cache(self._cache_dir, {"version": "1.2", "entries": {}})
            return True
        except Exception as e:
            log.error(f"[Cache Error] clear() failed: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Return the statistics as a dict, or ``{}`` if they cannot be read."""
        try:
            return self._stats.to_dict()
        except Exception as e:
            log.error(f"[Cache Error] get_stats() failed: {e}")
            return {}
@@ -0,0 +1,133 @@
1
+ import json
2
+ import os
3
+ from datetime import datetime, timezone
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Optional
6
+
7
+ from filelock import FileLock
8
+
9
+
10
class CacheStats:
    """Hit/miss counters for the cache, persisted as JSON in the cache directory.

    Counters are read from ``cache_stats.json`` once, when the object is
    created, and the file is rewritten after every increment.
    NOTE(review): the file lock guards each individual read/write, not the
    read-modify-write cycle, so several processes sharing one directory may
    overwrite each other's counts - confirm whether that matters.
    """

    _STATS_FILENAME = "cache_stats.json"

    def __init__(self, cache_dir: str = ".cache"):
        """Start from zeroed counters, then load any previously saved values."""
        self._cache_dir = cache_dir
        self._total_queries = 0
        self._cache_hits = 0
        self._cache_misses = 0
        self._total_api_calls_saved = 0
        self._first_cache_date = None
        self._last_cache_date = None
        self._load_stats()

    @property
    def total_queries(self) -> int:
        return self._total_queries

    @property
    def cache_hits(self) -> int:
        return self._cache_hits

    @property
    def cache_misses(self) -> int:
        return self._cache_misses

    @property
    def hit_rate(self) -> float:
        """Hits divided by total queries; 0.0 before any query is recorded."""
        if self._total_queries == 0:
            return 0.0
        return self._cache_hits / self._total_queries

    @property
    def total_api_calls_saved(self) -> int:
        return self._total_api_calls_saved

    @property
    def first_cache_date(self) -> Optional[str]:
        return self._first_cache_date

    @property
    def last_cache_date(self) -> Optional[str]:
        return self._last_cache_date

    def increment_query(self) -> None:
        """Count a query without classifying it as hit or miss, then persist."""
        self._total_queries += 1
        self._update_last_cache_date()
        self._save_stats()

    def increment_hit(self) -> None:
        """Count a query that was served from cache, then persist."""
        self._total_queries += 1
        self._cache_hits += 1
        self._total_api_calls_saved += 1
        self._update_last_cache_date()
        self._save_stats()

    def increment_miss(self) -> None:
        """Count a query that was not served from cache, then persist."""
        self._total_queries += 1
        self._cache_misses += 1
        self._update_last_cache_date()
        self._save_stats()

    def _update_first_cache_date(self, timestamp: Optional[str] = None) -> None:
        """Set the first-use timestamp once; later calls leave it untouched."""
        if not self._first_cache_date:
            self._first_cache_date = timestamp or datetime.now(timezone.utc).isoformat()

    def _update_last_cache_date(self) -> None:
        """Stamp the last-use time (UTC, ISO 8601) and seed the first-use time.

        One clock reading is shared so that the very first activity records
        identical first/last timestamps.
        """
        stamp = datetime.now(timezone.utc).isoformat()
        self._last_cache_date = stamp
        self._update_first_cache_date(stamp)

    def _stats_path(self) -> Path:
        """Return the location of the statistics file."""
        return Path(self._cache_dir) / self._STATS_FILENAME

    def _load_stats(self) -> None:
        """Load saved counters, keeping the zeroed defaults on any problem.

        A missing, unreadable or malformed file - or a cache directory that
        cannot be created - leaves the in-memory defaults in place rather than
        failing construction.
        """
        stats_path = self._stats_path()
        try:
            os.makedirs(self._cache_dir, exist_ok=True)
            with FileLock(f"{stats_path}.lock"):
                if not stats_path.exists():
                    return
                with open(stats_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
        except (ValueError, OSError):
            # ValueError covers JSONDecodeError and undecodable bytes.
            return

        if not isinstance(data, dict):
            return
        self._total_queries = data.get("total_queries", 0)
        self._cache_hits = data.get("cache_hits", 0)
        self._cache_misses = data.get("cache_misses", 0)
        self._total_api_calls_saved = data.get("total_api_calls_saved", 0)
        self._first_cache_date = data.get("first_cache_date")
        self._last_cache_date = data.get("last_cache_date")

    def _save_stats(self) -> None:
        """Write the counters via a temp file followed by ``os.replace``.

        Raises:
            OSError: if the directory or file cannot be written. The temp file
                is removed before the error propagates.
        """
        stats_path = self._stats_path()
        temp_path = stats_path.with_suffix(".tmp")
        os.makedirs(self._cache_dir, exist_ok=True)
        payload = self.to_dict()

        with FileLock(f"{stats_path}.lock"):
            try:
                with open(temp_path, "w", encoding="utf-8") as f:
                    json.dump(payload, f, ensure_ascii=False, indent=2)
                    f.flush()
                    os.fsync(f.fileno())
                os.replace(temp_path, stats_path)
            except BaseException:
                # Clean up while still holding the lock, whatever went wrong.
                try:
                    temp_path.unlink()
                except OSError:
                    pass  # nothing left to remove, or removal itself failed
                raise

    def to_dict(self) -> Dict[str, Any]:
        """Return the counters (plus the derived hit rate) as a plain dict."""
        return {
            "total_queries": self._total_queries,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate": self.hit_rate,
            "total_api_calls_saved": self._total_api_calls_saved,
            "first_cache_date": self._first_cache_date,
            "last_cache_date": self._last_cache_date,
        }
@@ -0,0 +1,79 @@
1
+ import json
2
+ import os
3
+ import time
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from typing import Any, Dict
7
+
8
+ from filelock import FileLock
9
+
10
+
11
def get_cache_filename() -> str:
    """Return the file name used for the persisted decision cache."""
    # Long-term retention and eviction of stale entries is not encoded in the
    # file name; it is left to cleanup_expired_entries (further down), via TTL.
    filename = "ai_decisions_v2.json"
    return filename
14
+
15
+
16
def _get_lock_path(cache_path: Path) -> str:
    """Return the lock-file path for *cache_path*: the same path plus ``.lock``."""
    return str(cache_path) + ".lock"
18
+
19
+
20
def load_cache(cache_dir: str) -> Dict[str, Any]:
    """Read the cache file from *cache_dir*, falling back to an empty cache.

    Returns:
        The parsed cache when the file holds a JSON object whose ``"entries"``
        value is itself an object. In every other case (file missing,
        unreadable, not valid JSON/UTF-8, or of the wrong shape) a fresh
        ``{"version": "1.1", "entries": {}}`` is returned, so callers can
        always treat ``result["entries"]`` as a dict.
    """

    def _empty() -> Dict[str, Any]:
        # A new dict per call, so callers may mutate the result freely.
        return {"version": "1.1", "entries": {}}

    cache_path = Path(cache_dir) / get_cache_filename()
    if not cache_path.exists():
        return _empty()

    try:
        with FileLock(_get_lock_path(cache_path)):
            with open(cache_path, "r", encoding="utf-8") as f:
                data = json.load(f)
    except (ValueError, OSError):
        # ValueError covers JSONDecodeError and undecodable bytes; OSError
        # covers the file disappearing or being unreadable.
        return _empty()

    if isinstance(data, dict) and isinstance(data.get("entries"), dict):
        return data
    return _empty()
33
+
34
+
35
def save_cache(cache_dir: str, data: Dict[str, Any]) -> None:
    """Write *data* to the cache file in *cache_dir*.

    The payload goes to a sibling ``.tmp`` file which is flushed, fsynced and
    then moved over the real file with ``os.replace``, all while holding the
    cache's file lock.

    Raises:
        OSError: if the directory or file cannot be written.
        TypeError / ValueError: if *data* is not JSON-serialisable.
        In both cases the temp file is removed before the error propagates.
    """
    cache_path = Path(cache_dir) / get_cache_filename()
    temp_path = cache_path.with_suffix(".tmp")
    os.makedirs(cache_dir, exist_ok=True)

    with FileLock(_get_lock_path(cache_path)):
        try:
            with open(temp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(temp_path, cache_path)
        except BaseException:
            # Remove the half-written temp file for any failure (not only
            # I/O errors), and do so while the lock is still held so another
            # writer's temp file is never the one deleted.
            try:
                temp_path.unlink()
            except OSError:
                pass  # nothing left to remove, or removal itself failed
            raise
52
+
53
+
54
def cleanup_expired_entries(
    data: Dict[str, Any], default_ttl_seconds: float
) -> Dict[str, Any]:
    """Return a copy of the cache that keeps only entries still within TTL.

    Args:
        data: Cache mapping; returned unchanged when it has no ``"entries"``.
        default_ttl_seconds: TTL applied to entries whose metadata has no
            numeric ``"ttl_seconds"`` of its own.

    Returns:
        ``{"version": ..., "entries": {...}}`` where the version defaults to
        ``"1.1"``. An entry is dropped when it is not a dict, has no usable
        ``metadata["created_at"]`` ISO-8601 string, or is older than its TTL.
        Timestamps without a UTC offset are interpreted as UTC.
    """
    # Local import: this module's header only imports `datetime` itself.
    from datetime import timezone

    if "entries" not in data:
        return data

    now = time.time()
    cleaned_entries = {}

    for key, entry in data["entries"].items():
        if not isinstance(entry, dict):
            continue
        metadata = entry.get("metadata")
        if not isinstance(metadata, dict):
            continue
        created_at = metadata.get("created_at")
        if not created_at or not isinstance(created_at, str):
            continue

        try:
            # A trailing "Z" is rewritten to an explicit UTC offset.
            created_time = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
            if created_time.tzinfo is None:
                # Naive value: treat as UTC instead of the machine's local zone.
                created_time = created_time.replace(tzinfo=timezone.utc)
            created_epoch = created_time.timestamp()
        except (ValueError, OverflowError, OSError):
            continue

        entry_ttl = metadata.get("ttl_seconds")
        if not isinstance(entry_ttl, (int, float)):
            entry_ttl = default_ttl_seconds

        if now - created_epoch <= entry_ttl:
            cleaned_entries[key] = entry

    return {"version": data.get("version", "1.1"), "entries": cleaned_entries}
@@ -0,0 +1,150 @@
1
+ from pathlib import Path
2
+ from typing import Any, Optional
3
+
4
+ from common.logs import log
5
+
6
+ # NOTE: `sentence_transformers` (and its torch/transformers stack, ~2GB) is an
7
+ # OPTIONAL dependency — installed via `pip install screenforge[ml]`, not the core
8
+ # requirements. It is imported lazily inside load(), NOT at module scope, so that
9
+ # importing common.cache / common.ai on a clean (core-only) install does not crash
10
+ # with ModuleNotFoundError. The semantic (vector) cache degrades gracefully when
11
+ # the package is absent; the exact-key (hash) cache keeps working.
12
+
13
+
14
class EmbeddingModelLoader:
    """Loads the sentence-embedding model and keeps it cached on the instance.

    SECURITY NOTE(review): ``disable_ssl_verify`` defaults to True, which turns
    off TLS certificate verification for ``requests`` sessions and ``httpx``
    clients created while the model is loading. The default is kept for
    backward compatibility; pass ``disable_ssl_verify=False`` unless the
    bypass is really required.
    """

    def __init__(
        self,
        model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
        hf_cache_dir: Optional[Path] = None,
        disable_ssl_verify: bool = True,
    ):
        """Configure the loader; nothing is imported or downloaded yet.

        Args:
            model_name: Name passed to ``SentenceTransformer``.
            hf_cache_dir: Directory checked for an existing download and
                removed when the cache looks corrupted. Defaults to the path
                from ``_default_cache_dir``.
                NOTE(review): this path is not passed to the model
                constructor - confirm custom directories behave as intended.
            disable_ssl_verify: Patch HTTP clients to skip certificate
                verification while loading (see the class security note).
        """
        self.model_name = model_name
        self.hf_cache_dir = hf_cache_dir or self._default_cache_dir(model_name)
        self.disable_ssl_verify = disable_ssl_verify
        self._model = None
        # Originals saved by _configure_network; None means "not patched".
        self._original_requests_init = None
        self._original_httpx_init = None

    @staticmethod
    def _default_cache_dir(model_name: str) -> Path:
        """Return the default HuggingFace hub cache directory for the model."""
        return (
            Path.home()
            / ".cache"
            / "huggingface"
            / "hub"
            / f"models--sentence-transformers--{model_name}"
        )

    def _configure_network(self):
        """Patch HTTP clients to skip TLS verification (load-time only).

        Does nothing when ``disable_ssl_verify`` is False. Each of
        ``urllib3``, ``requests`` and ``httpx`` is optional: a library that is
        not installed is simply skipped.
        """
        if not self.disable_ssl_verify:
            return

        try:
            import urllib3

            urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        except ImportError:
            pass  # no urllib3 means there are no warnings to silence

        try:
            import requests
        except ImportError:
            self._original_requests_init = None
        else:
            # Keep a reference to the original; the closure captures a local.
            original_req_init = requests.Session.__init__
            self._original_requests_init = original_req_init

            def safe_session_init(sess_self, *args, **kwargs):
                # sess_self is the requests.Session instance being created.
                original_req_init(sess_self, *args, **kwargs)
                sess_self.verify = False

            requests.Session.__init__ = safe_session_init

        try:
            import httpx
        except ImportError:
            self._original_httpx_init = None
        else:
            # Keep a reference to the original; the closure captures a local.
            original_httpx_init = httpx.Client.__init__
            self._original_httpx_init = original_httpx_init

            def safe_httpx_init(client_self, *args, **kwargs):
                # client_self is the httpx.Client instance being created.
                kwargs["verify"] = False
                original_httpx_init(client_self, *args, **kwargs)

            httpx.Client.__init__ = safe_httpx_init

    def _restore_network(self):
        """Undo ``_configure_network`` so the patches do not outlive the load.

        Only what was actually patched is restored, and the saved references
        are cleared afterwards so a second call is a harmless no-op.
        """
        if self._original_requests_init is not None:
            import requests

            requests.Session.__init__ = self._original_requests_init
            self._original_requests_init = None

        if self._original_httpx_init is not None:
            import httpx

            httpx.Client.__init__ = self._original_httpx_init
            self._original_httpx_init = None

    def _cleanup_corrupted_cache(self) -> bool:
        """Delete ``hf_cache_dir``; return False when there was nothing to delete."""
        if not self.hf_cache_dir.exists():
            return False

        import shutil

        log.warning("[System] Corrupted model cache detected, cleaning up...")
        shutil.rmtree(self.hf_cache_dir)
        log.info("[System] Corrupted cache cleaned")
        return True

    def _should_cleanup_cache(self, error_msg: str) -> bool:
        """Return True when *error_msg* contains a model-file failure marker."""
        error_indicators = ["Can't load the model", "pytorch_model.bin", "safetensors"]
        return any(indicator in error_msg for indicator in error_indicators)

    def load(self) -> Optional[Any]:
        """Return the embedding model, loading it on first use.

        Returns:
            The cached model, or None when ``sentence_transformers`` is not
            installed (graceful degradation - callers skip vector search).

        Raises:
            Exception: whatever ``SentenceTransformer`` raises when the model
                cannot be loaded. If the message points at broken model files,
                the cache directory is wiped and the load retried once first.
        """
        if self._model is not None:
            return self._model

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            log.warning(
                "[System] Semantic cache disabled — 'sentence_transformers' not "
                "installed. Install with: pip install screenforge[ml] (exact-key "
                "cache still works without it)."
            )
            return None

        log.info("[System] Initializing local semantic cache engine...")

        if not self.hf_cache_dir.exists():
            log.warning("[System] First run — downloading embedding model (~100MB)...")
        else:
            log.info("[System] Loading cached embedding model...")

        try:
            # Patched inside the try so a failure half-way through patching
            # is still undone by the finally clause.
            self._configure_network()
            try:
                self._model = SentenceTransformer(self.model_name)
            except Exception as e:
                if not self._should_cleanup_cache(str(e)):
                    raise
                self._cleanup_corrupted_cache()
                self._model = SentenceTransformer(self.model_name)

            log.info("[System] Semantic cache model loaded and ready")
            return self._model
        finally:
            self._restore_network()