sirchmunk 0.0.1.post1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. sirchmunk/api/__init__.py +1 -0
  2. sirchmunk/api/chat.py +1123 -0
  3. sirchmunk/api/components/__init__.py +0 -0
  4. sirchmunk/api/components/history_storage.py +402 -0
  5. sirchmunk/api/components/monitor_tracker.py +518 -0
  6. sirchmunk/api/components/settings_storage.py +353 -0
  7. sirchmunk/api/history.py +254 -0
  8. sirchmunk/api/knowledge.py +411 -0
  9. sirchmunk/api/main.py +120 -0
  10. sirchmunk/api/monitor.py +219 -0
  11. sirchmunk/api/run_server.py +54 -0
  12. sirchmunk/api/search.py +230 -0
  13. sirchmunk/api/settings.py +309 -0
  14. sirchmunk/api/tools.py +315 -0
  15. sirchmunk/cli/__init__.py +11 -0
  16. sirchmunk/cli/cli.py +789 -0
  17. sirchmunk/learnings/knowledge_base.py +5 -2
  18. sirchmunk/llm/prompts.py +12 -1
  19. sirchmunk/retrieve/text_retriever.py +186 -2
  20. sirchmunk/scan/file_scanner.py +2 -2
  21. sirchmunk/schema/knowledge.py +119 -35
  22. sirchmunk/search.py +384 -26
  23. sirchmunk/storage/__init__.py +2 -2
  24. sirchmunk/storage/{knowledge_manager.py → knowledge_storage.py} +265 -60
  25. sirchmunk/utils/constants.py +7 -5
  26. sirchmunk/utils/embedding_util.py +217 -0
  27. sirchmunk/utils/tokenizer_util.py +36 -1
  28. sirchmunk/version.py +1 -1
  29. {sirchmunk-0.0.1.post1.dist-info → sirchmunk-0.0.2.dist-info}/METADATA +124 -9
  30. sirchmunk-0.0.2.dist-info/RECORD +69 -0
  31. {sirchmunk-0.0.1.post1.dist-info → sirchmunk-0.0.2.dist-info}/WHEEL +1 -1
  32. sirchmunk-0.0.2.dist-info/top_level.txt +2 -0
  33. sirchmunk_mcp/__init__.py +25 -0
  34. sirchmunk_mcp/cli.py +478 -0
  35. sirchmunk_mcp/config.py +276 -0
  36. sirchmunk_mcp/server.py +355 -0
  37. sirchmunk_mcp/service.py +327 -0
  38. sirchmunk_mcp/setup.py +15 -0
  39. sirchmunk_mcp/tools.py +410 -0
  40. sirchmunk-0.0.1.post1.dist-info/RECORD +0 -45
  41. sirchmunk-0.0.1.post1.dist-info/top_level.txt +0 -1
  42. {sirchmunk-0.0.1.post1.dist-info → sirchmunk-0.0.2.dist-info}/entry_points.txt +0 -0
  43. {sirchmunk-0.0.1.post1.dist-info → sirchmunk-0.0.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ Manages KnowledgeCluster objects with persistence
6
6
 
7
7
  import os
8
8
  import json
9
+ import asyncio
9
10
  from typing import Dict, Any, List, Optional
10
11
  from pathlib import Path
11
12
  from datetime import datetime
@@ -20,10 +21,10 @@ from sirchmunk.schema.knowledge import (
20
21
  Lifecycle,
21
22
  AbstractionLevel
22
23
  )
23
- from ..utils.constants import DEFAULT_WORK_PATH
24
+ from ..utils.constants import DEFAULT_SIRCHMUNK_WORK_PATH
24
25
 
25
26
 
26
- class KnowledgeManager:
27
+ class KnowledgeStorage:
27
28
  """
28
29
  Manages persistent storage of KnowledgeCluster objects using DuckDB and Parquet
29
30
 
@@ -33,7 +34,7 @@ class KnowledgeManager:
33
34
  - Provides full CRUD operations with fuzzy search capabilities
34
35
  - Follows Single Responsibility Principle (SRP)
35
36
 
36
- Storage Path: {WORK_PATH}/.cache/knowledge/
37
+ Storage Path: {SIRCHMUNK_WORK_PATH}/.cache/knowledge/
37
38
  """
38
39
 
39
40
  def __init__(self, work_path: Optional[str] = None):
@@ -41,19 +42,22 @@ class KnowledgeManager:
41
42
  Initialize Knowledge Manager
42
43
 
43
44
  Args:
44
- work_path: Base work path. If None, uses WORK_PATH env variable
45
+ work_path: Base work path. If None, uses SIRCHMUNK_WORK_PATH env variable
45
46
  """
46
- # Get work path from env if not provided
47
+ # Get work path from env if not provided, and expand ~ in path
47
48
  if work_path is None:
48
- work_path = os.getenv("WORK_PATH", DEFAULT_WORK_PATH)
49
+ work_path = os.getenv("SIRCHMUNK_WORK_PATH", DEFAULT_SIRCHMUNK_WORK_PATH)
49
50
 
50
- # Create knowledge storage path
51
- self.knowledge_path = Path(work_path) / ".cache" / "knowledge"
51
+ # Create knowledge storage path (expand ~ and resolve to absolute path)
52
+ self.knowledge_path = Path(work_path).expanduser().resolve() / ".cache" / "knowledge"
52
53
  self.knowledge_path.mkdir(parents=True, exist_ok=True)
53
54
 
54
55
  # Parquet file path
55
56
  self.parquet_file = str(self.knowledge_path / "knowledge_clusters.parquet")
56
57
 
58
+ # Initialize async lock for thread-safe parquet operations
59
+ self._parquet_lock = asyncio.Lock()
60
+
57
61
  # Initialize DuckDB (in-memory for fast operations)
58
62
  self.db = DuckDBManager(db_path=None) # In-memory database
59
63
 
@@ -107,19 +111,50 @@ class KnowledgeManager:
107
111
  "version": "INTEGER",
108
112
  "related_clusters": "VARCHAR", # JSON array
109
113
  "search_results": "VARCHAR", # JSON array
114
+ "queries": "VARCHAR", # JSON array of historical queries
115
+ "embedding_vector": "FLOAT[384]", # 384-dim embedding vector
116
+ "embedding_model": "VARCHAR", # Model identifier
117
+ "embedding_timestamp": "TIMESTAMP", # Embedding computation time
118
+ "embedding_text_hash": "VARCHAR", # Hash of embedded text
110
119
  }
111
120
  self.db.create_table(self.table_name, schema, if_not_exists=True)
112
121
  logger.info(f"Created table {self.table_name}")
113
122
 
114
- def _save_to_parquet(self):
115
- """Save current knowledge clusters to parquet file"""
116
- try:
117
- # Export table to parquet
118
- self.db.export_to_parquet(self.table_name, self.parquet_file)
119
- logger.debug(f"Saved knowledge clusters to {self.parquet_file}")
120
- except Exception as e:
121
- logger.error(f"Failed to save to parquet: {e}")
122
- raise
123
+ async def _save_to_parquet(self):
124
+ """
125
+ Save current knowledge clusters to parquet file with thread-safe atomic write.
126
+ Uses async lock to prevent concurrent writes and temp file + rename for atomicity.
127
+ """
128
+ async with self._parquet_lock:
129
+ temp_file = None
130
+ try:
131
+ # Generate temporary file path with timestamp for uniqueness
132
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
133
+ temp_file = f"{self.parquet_file}.tmp_{timestamp}"
134
+
135
+ # Export table to temporary parquet file
136
+ self.db.export_to_parquet(self.table_name, temp_file)
137
+
138
+ # Verify temporary file was created successfully
139
+ if not Path(temp_file).exists():
140
+ raise IOError(f"Temporary file not created: {temp_file}")
141
+
142
+ # Atomically replace the target file with the temporary file
143
+ # os.replace() is atomic on both Unix and Windows
144
+ os.replace(temp_file, self.parquet_file)
145
+
146
+ logger.debug(f"Atomically saved knowledge clusters to {self.parquet_file}")
147
+
148
+ except Exception as e:
149
+ logger.error(f"Failed to save to parquet: {e}")
150
+ # Clean up temporary file if it exists
151
+ if temp_file and Path(temp_file).exists():
152
+ try:
153
+ Path(temp_file).unlink()
154
+ logger.debug(f"Cleaned up temporary file: {temp_file}")
155
+ except Exception as cleanup_error:
156
+ logger.warning(f"Failed to clean up temp file {temp_file}: {cleanup_error}")
157
+ raise
123
158
 
124
159
  def _cluster_to_row(self, cluster: KnowledgeCluster) -> Dict[str, Any]:
125
160
  """Convert KnowledgeCluster to database row"""
@@ -155,34 +190,32 @@ class KnowledgeManager:
155
190
  "version": cluster.version,
156
191
  "related_clusters": json.dumps([rc.to_dict() for rc in cluster.related_clusters]),
157
192
  "search_results": json.dumps(cluster.search_results) if cluster.search_results else None,
193
+ "queries": json.dumps(cluster.queries) if cluster.queries else None,
158
194
  }
159
195
 
160
196
  def _row_to_cluster(self, row: tuple) -> KnowledgeCluster:
161
- """Convert database row to KnowledgeCluster"""
162
- # Unpack row (order matches schema). Older tables may not include search_results.
163
- if len(row) == 19:
164
- (
165
- id, name, description, content, scripts, resources, evidences, patterns,
166
- constraints, confidence, abstraction_level, landmark_potential, hotness,
167
- lifecycle, create_time, last_modified, version, related_clusters, search_results
168
- ) = row
169
- elif len(row) == 18:
170
- (
171
- id, name, description, content, scripts, resources, evidences, patterns,
172
- constraints, confidence, abstraction_level, landmark_potential, hotness,
173
- lifecycle, create_time, last_modified, version, related_clusters
174
- ) = row
175
- search_results = None
176
- elif len(row) == 17:
177
- (
178
- id, name, description, content, scripts, resources, evidences, patterns,
179
- constraints, confidence, abstraction_level, landmark_potential, hotness,
180
- lifecycle, create_time, last_modified, version
181
- ) = row
182
- related_clusters = None
183
- search_results = None
184
- else:
185
- raise ValueError(f"Unexpected knowledge_clusters row length: {len(row)}")
197
+ """
198
+ Convert database row to KnowledgeCluster.
199
+
200
+ Expected row structure (24 columns):
201
+ id, name, description, content, scripts, resources, evidences, patterns,
202
+ constraints, confidence, abstraction_level, landmark_potential, hotness,
203
+ lifecycle, create_time, last_modified, version, related_clusters, search_results, queries,
204
+ embedding_vector, embedding_model, embedding_timestamp, embedding_text_hash
205
+ """
206
+ if len(row) != 24:
207
+ raise ValueError(
208
+ f"Expected 24 columns in knowledge_clusters row, got {len(row)}. "
209
+ f"Please ensure the table schema is up to date."
210
+ )
211
+
212
+ # Unpack row (embedding fields are ignored as they're not part of KnowledgeCluster schema)
213
+ (
214
+ id, name, description, content, scripts, resources, evidences, patterns,
215
+ constraints, confidence, abstraction_level, landmark_potential, hotness,
216
+ lifecycle, create_time, last_modified, version, related_clusters, search_results, queries,
217
+ _embedding_vector, _embedding_model, _embedding_timestamp, _embedding_text_hash
218
+ ) = row
186
219
 
187
220
  # Parse JSON fields
188
221
  try:
@@ -204,13 +237,22 @@ class KnowledgeManager:
204
237
  if evidences:
205
238
  evidences_data = json.loads(evidences)
206
239
  for ev_dict in evidences_data:
240
+ # Parse extracted_at field (handle both string and datetime types)
241
+ extracted_at_raw = ev_dict.get("extracted_at")
242
+ extracted_at_parsed = None
243
+ if extracted_at_raw:
244
+ if isinstance(extracted_at_raw, str):
245
+ extracted_at_parsed = datetime.fromisoformat(extracted_at_raw)
246
+ elif isinstance(extracted_at_raw, datetime):
247
+ extracted_at_parsed = extracted_at_raw
248
+
207
249
  evidences_parsed.append(EvidenceUnit(
208
250
  doc_id=ev_dict["doc_id"],
209
251
  file_or_url=Path(ev_dict["file_or_url"]),
210
252
  summary=ev_dict["summary"],
211
253
  is_found=ev_dict["is_found"],
212
254
  snippets=ev_dict["snippets"],
213
- extracted_at=datetime.fromisoformat(ev_dict["extracted_at"]),
255
+ extracted_at=extracted_at_parsed or datetime.now(),
214
256
  conflict_group=ev_dict.get("conflict_group")
215
257
  ))
216
258
 
@@ -233,6 +275,26 @@ class KnowledgeManager:
233
275
  if search_results:
234
276
  search_results_parsed = json.loads(search_results)
235
277
 
278
+ # Parse queries
279
+ queries_parsed = []
280
+ if queries:
281
+ queries_parsed = json.loads(queries)
282
+
283
+ # Parse datetime fields (handle both string and datetime types)
284
+ create_time_parsed = None
285
+ if create_time:
286
+ if isinstance(create_time, str):
287
+ create_time_parsed = datetime.fromisoformat(create_time)
288
+ elif isinstance(create_time, datetime):
289
+ create_time_parsed = create_time
290
+
291
+ last_modified_parsed = None
292
+ if last_modified:
293
+ if isinstance(last_modified, str):
294
+ last_modified_parsed = datetime.fromisoformat(last_modified)
295
+ elif isinstance(last_modified, datetime):
296
+ last_modified_parsed = last_modified
297
+
236
298
  return KnowledgeCluster(
237
299
  id=id,
238
300
  name=name,
@@ -248,11 +310,12 @@ class KnowledgeManager:
248
310
  landmark_potential=landmark_potential,
249
311
  hotness=hotness,
250
312
  lifecycle=Lifecycle[lifecycle],
251
- create_time=datetime.fromisoformat(create_time) if create_time else None,
252
- last_modified=datetime.fromisoformat(last_modified) if last_modified else None,
313
+ create_time=create_time_parsed,
314
+ last_modified=last_modified_parsed,
253
315
  version=version,
254
316
  related_clusters=related_clusters_parsed,
255
317
  search_results=search_results_parsed,
318
+ queries=queries_parsed,
256
319
  )
257
320
 
258
321
  async def get(self, cluster_id: str) -> Optional[KnowledgeCluster]:
@@ -308,8 +371,8 @@ class KnowledgeManager:
308
371
  row = self._cluster_to_row(cluster)
309
372
  self.db.insert_data(self.table_name, row)
310
373
 
311
- # Save to parquet
312
- self._save_to_parquet()
374
+ # Save to parquet with atomic write
375
+ await self._save_to_parquet()
313
376
 
314
377
  logger.info(f"Inserted cluster: {cluster.id}")
315
378
  return True
@@ -351,8 +414,8 @@ class KnowledgeManager:
351
414
  where_params=[cluster.id]
352
415
  )
353
416
 
354
- # Save to parquet
355
- self._save_to_parquet()
417
+ # Save to parquet with atomic write
418
+ await self._save_to_parquet()
356
419
 
357
420
  logger.info(f"Updated cluster: {cluster.id} (version {cluster.version})")
358
421
  return True
@@ -381,8 +444,8 @@ class KnowledgeManager:
381
444
  # Delete from database
382
445
  self.db.delete_data(self.table_name, "id = ?", [cluster_id])
383
446
 
384
- # Save to parquet
385
- self._save_to_parquet()
447
+ # Save to parquet with atomic write
448
+ await self._save_to_parquet()
386
449
 
387
450
  logger.info(f"Removed cluster: {cluster_id}")
388
451
  return True
@@ -666,9 +729,7 @@ class KnowledgeManager:
666
729
  Dictionary with statistics
667
730
  """
668
731
  try:
669
- stats = self.db.analyze_table(self.table_name)
670
-
671
- # Add custom stats
732
+ # Get basic table count
672
733
  total_count = self.db.get_table_count(self.table_name)
673
734
 
674
735
  # Count by lifecycle
@@ -686,12 +747,24 @@ class KnowledgeManager:
686
747
  )
687
748
  avg_confidence = avg_confidence_row[0] if avg_confidence_row and avg_confidence_row[0] else 0
688
749
 
689
- stats["custom_stats"] = {
690
- "total_clusters": total_count,
691
- "lifecycle_distribution": lifecycle_counts,
692
- "average_confidence": round(avg_confidence, 4) if avg_confidence else None,
693
- "parquet_file": self.parquet_file,
694
- "parquet_exists": Path(self.parquet_file).exists(),
750
+ # Count clusters with embeddings
751
+ embedding_count_row = self.db.fetch_one(
752
+ f"SELECT COUNT(*) FROM {self.table_name} WHERE embedding_vector IS NOT NULL"
753
+ )
754
+ embedding_count = embedding_count_row[0] if embedding_count_row else 0
755
+
756
+ # Build stats dictionary
757
+ stats = {
758
+ "table_name": self.table_name,
759
+ "row_count": total_count,
760
+ "custom_stats": {
761
+ "total_clusters": total_count,
762
+ "clusters_with_embeddings": embedding_count,
763
+ "lifecycle_distribution": lifecycle_counts,
764
+ "average_confidence": round(avg_confidence, 4) if avg_confidence else None,
765
+ "parquet_file": self.parquet_file,
766
+ "parquet_exists": Path(self.parquet_file).exists(),
767
+ }
695
768
  }
696
769
 
697
770
  return stats
@@ -700,6 +773,138 @@ class KnowledgeManager:
700
773
  logger.error(f"Failed to get stats: {e}")
701
774
  return {}
702
775
 
776
+ @staticmethod
777
+ def combine_cluster_fields(queries: List[str]) -> str:
778
+ """
779
+ Combine cluster queries into single text for embedding.
780
+
781
+ Args:
782
+ queries: List of historical user queries
783
+
784
+ Returns:
785
+ Combined text string
786
+ """
787
+ if not queries:
788
+ return "unknown"
789
+
790
+ # Join all queries with newline separator
791
+ return "\n".join(queries)
792
+
793
+ async def store_embedding(
794
+ self,
795
+ cluster_id: str,
796
+ embedding_vector: List[float],
797
+ embedding_model: str,
798
+ embedding_text_hash: str
799
+ ) -> bool:
800
+ """
801
+ Store embedding vector for a knowledge cluster.
802
+
803
+ Args:
804
+ cluster_id: Cluster ID
805
+ embedding_vector: 384-dim embedding vector
806
+ embedding_model: Model identifier used for embedding
807
+ embedding_text_hash: Hash of the text that was embedded
808
+
809
+ Returns:
810
+ True if successful, False otherwise
811
+ """
812
+ try:
813
+ # Verify embedding dimension
814
+ if len(embedding_vector) != 384:
815
+ logger.error(
816
+ f"Invalid embedding dimension: expected 384, got {len(embedding_vector)}"
817
+ )
818
+ return False
819
+
820
+ # Update embedding fields in database
821
+ self.db.execute(
822
+ f"""
823
+ UPDATE {self.table_name}
824
+ SET
825
+ embedding_vector = ?::FLOAT[384],
826
+ embedding_model = ?,
827
+ embedding_timestamp = CURRENT_TIMESTAMP,
828
+ embedding_text_hash = ?
829
+ WHERE id = ?
830
+ """,
831
+ [embedding_vector, embedding_model, embedding_text_hash, cluster_id]
832
+ )
833
+
834
+ # Save to parquet with atomic write
835
+ await self._save_to_parquet()
836
+
837
+ logger.debug(f"Stored embedding for cluster {cluster_id}")
838
+ return True
839
+
840
+ except Exception as e:
841
+ logger.error(f"Failed to store embedding for cluster {cluster_id}: {e}")
842
+ return False
843
+
844
+ async def search_similar_clusters(
845
+ self,
846
+ query_embedding: List[float],
847
+ top_k: int = 3,
848
+ similarity_threshold: float = 0.82
849
+ ) -> List[Dict[str, Any]]:
850
+ """
851
+ Search for similar clusters using vector similarity.
852
+
853
+ Args:
854
+ query_embedding: 384-dim query embedding vector
855
+ top_k: Maximum number of results to return
856
+ similarity_threshold: Minimum cosine similarity threshold
857
+
858
+ Returns:
859
+ List of similar clusters with metadata and similarity scores
860
+ """
861
+ try:
862
+ # Verify query embedding dimension
863
+ if len(query_embedding) != 384:
864
+ logger.error(
865
+ f"Invalid query embedding dimension: expected 384, got {len(query_embedding)}"
866
+ )
867
+ return []
868
+
869
+ # DuckDB cosine similarity query
870
+ query = f"""
871
+ SELECT
872
+ id, name, description, confidence, hotness,
873
+ list_cosine_similarity(embedding_vector, ?::FLOAT[384]) AS similarity
874
+ FROM {self.table_name}
875
+ WHERE embedding_vector IS NOT NULL
876
+ ORDER BY similarity DESC
877
+ LIMIT ?
878
+ """
879
+
880
+ results = self.db.fetch_all(query, [query_embedding, top_k])
881
+
882
+ # Filter by similarity threshold
883
+ filtered_results = []
884
+ for row in results:
885
+ similarity = row[5]
886
+ if similarity >= similarity_threshold:
887
+ filtered_results.append({
888
+ "id": row[0],
889
+ "name": row[1],
890
+ "description": row[2],
891
+ "confidence": row[3],
892
+ "hotness": row[4],
893
+ "similarity": similarity
894
+ })
895
+
896
+ if filtered_results:
897
+ logger.debug(
898
+ f"Found {len(filtered_results)} similar clusters "
899
+ f"(threshold: {similarity_threshold})"
900
+ )
901
+
902
+ return filtered_results
903
+
904
+ except Exception as e:
905
+ logger.error(f"Failed to search similar clusters: {e}")
906
+ return []
907
+
703
908
  def close(self):
704
909
  """Close database connection"""
705
910
  if self.db:
@@ -6,10 +6,12 @@ from pathlib import Path
6
6
  GREP_CONCURRENT_LIMIT = int(os.getenv("GREP_CONCURRENT_LIMIT", "5"))
7
7
 
8
8
  # LLM Configuration
9
- LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
9
+ LLM_BASE_URL = os.getenv("LLM_BASE_URL", "https://api.openai.com/v1")
10
10
  LLM_API_KEY = os.getenv("LLM_API_KEY", "")
11
- LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "qwen3-max")
11
+ LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "gpt-5.2")
12
12
 
13
- # Search Configuration
14
- DEFAULT_WORK_PATH = os.path.expanduser("~/sirchmunk")
15
- WORK_PATH = os.getenv("WORK_PATH", DEFAULT_WORK_PATH)
13
+ # Sirchmunk Working Directory Configuration
14
+ DEFAULT_SIRCHMUNK_WORK_PATH = os.path.expanduser("~/.sirchmunk")
15
+ # Expand ~ in environment variable if set
16
+ _env_work_path = os.getenv("SIRCHMUNK_WORK_PATH")
17
+ SIRCHMUNK_WORK_PATH = os.path.expanduser(_env_work_path) if _env_work_path else DEFAULT_SIRCHMUNK_WORK_PATH
@@ -0,0 +1,217 @@
1
+ # Copyright (c) ModelScope Contributors. All rights reserved.
2
+ """
3
+ Local Embedding Utility
4
+ Provides embedding computation using SentenceTransformer models loaded from ModelScope
5
+ """
6
+
7
+ import asyncio
8
+ import hashlib
9
+ import warnings
10
+ from typing import List, Optional, Dict, Any
11
+
12
+ import torch
13
+ import numpy as np
14
+ from loguru import logger
15
+
16
+
17
class EmbeddingUtil:
    """
    Embedding utility using SentenceTransformer models.

    Downloads the model snapshot from ModelScope, loads it through
    sentence-transformers, and exposes async batch embedding that returns
    L2-normalized 384-dimensional vectors.
    """

    DEFAULT_MODEL_ID = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    EMBEDDING_DIM = 384

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
        cache_dir: Optional[str] = None
    ):
        """
        Initialize local embedding client.

        Args:
            model_id: ModelScope model identifier
            device: Device for inference ("cuda", "cpu", or None for auto-detection)
            cache_dir: Optional cache directory for model files
        """
        self.model_id = model_id

        # Auto-detect device if not specified
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Download (or reuse cached) model files
        model_dir = self._load_model(model_id, cache_dir)

        # Import SentenceTransformer lazily and load the model with the
        # known-benign 'embeddings.position_ids' warning suppressed
        from sentence_transformers import SentenceTransformer
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*position_ids.*")
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.model = SentenceTransformer(model_dir, device=self.device)

        # Warm up model with dummy inference
        self._warmup()

        logger.debug(
            f"Loaded embedding model: {model_id} "
            f"(device={self.device}, dim={self.dimension})"
        )

    @staticmethod
    def _load_model(model_id: str, cache_dir: Optional[str] = None) -> str:
        """
        Download the embedding model snapshot via ModelScope.

        Args:
            model_id: Model identifier
            cache_dir: Optional cache directory for model files

        Returns:
            Path to the downloaded model directory

        Raises:
            RuntimeError: If the snapshot download fails
        """
        try:
            from modelscope import snapshot_download

            model_dir = snapshot_download(
                model_id=model_id,
                cache_dir=cache_dir,
                # Skip alternative-runtime weight formats that are not needed
                ignore_patterns=[
                    "openvino/*", "onnx/*", "pytorch_model.bin",
                    "rust_model.ot", "tf_model.h5"
                ]
            )
            logger.debug(f"Model loaded successfully: {model_dir}")
            return model_dir

        except Exception as e:
            logger.error(f"Failed to load model {model_id}: {e}")
            # Chain the original cause explicitly for clearer tracebacks
            raise RuntimeError(
                f"Model loading failed. Please check network or model_id. Error: {e}"
            ) from e

    def _warmup(self):
        """Warm up model with dummy inference to avoid first-call latency"""
        try:
            dummy_text = ["warmup text"]
            _ = self.model.encode(dummy_text, show_progress_bar=False)
            logger.debug("Model warmup completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {e}")

    async def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for batch texts using the local model.

        Args:
            texts: List of input texts

        Returns:
            List of embedding vectors (each of dimension 384)
        """
        if not texts:
            return []

        # encode() is compute-bound, so run it in the default thread pool to
        # keep the event loop responsive. get_running_loop() is the
        # non-deprecated way to obtain the loop inside a coroutine
        # (get_event_loop() is deprecated for this use since Python 3.10).
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None,  # Use default ThreadPoolExecutor
            self._encode_sync,
            texts
        )

        return embeddings.tolist()

    def _encode_sync(self, texts: List[str]) -> np.ndarray:
        """
        Synchronous encoding wrapper for thread pool execution.

        Args:
            texts: List of texts to encode

        Returns:
            Numpy array of L2-normalized embeddings
        """
        return self.model.encode(
            texts,
            normalize_embeddings=True,  # L2 normalization for cosine similarity
            show_progress_bar=False,
            convert_to_numpy=True
        )

    @property
    def dimension(self) -> int:
        """Return embedding dimension (384 for MiniLM-L12-v2)"""
        return self.model.get_sentence_embedding_dimension()

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return model metadata.

        Returns:
            Dictionary with model information
        """
        return {
            "model_id": self.model_id,
            "dimension": self.dimension,
            "device": self.device,
            "max_seq_length": self.model.max_seq_length,
        }

    @classmethod
    def preload_model(
        cls,
        cache_dir: Optional[str] = None,
        model_id: Optional[str] = None,
    ) -> str:
        """
        Pre-download the embedding model without initializing it.

        Useful during setup to fetch the model files without loading
        them into memory.

        Args:
            cache_dir: Cache directory for model files
            model_id: Model identifier (uses default if None)

        Returns:
            Path to downloaded model directory
        """
        return cls._load_model(model_id or cls.DEFAULT_MODEL_ID, cache_dir)
190
+
191
+
192
def compute_text_hash(text: str) -> str:
    """
    Compute a short SHA256 digest for text content.

    Args:
        text: Input text

    Returns:
        First 16 hex characters of the SHA256 digest
    """
    digest = hashlib.sha256(text.encode('utf-8'))
    return digest.hexdigest()[:16]
203
+
204
+
205
if __name__ == '__main__':
    # Example usage: embed two sample texts and print their vectors
    import asyncio

    async def _demo():
        util = EmbeddingUtil()
        sample_texts = ["Hello world", "ModelScope embedding"]
        vectors = await util.embed(sample_texts)
        for text, emb in zip(sample_texts, vectors):
            print(f"Text: {text}\nEmbedding: {emb}\n")

    asyncio.run(_demo())