isage-middleware 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of isage-middleware might be problematic. Click here for more details.
- isage_middleware-0.1.0.dist-info/METADATA +424 -0
- isage_middleware-0.1.0.dist-info/RECORD +191 -0
- isage_middleware-0.1.0.dist-info/WHEEL +5 -0
- isage_middleware-0.1.0.dist-info/top_level.txt +1 -0
- sage/__init__.py +2 -0
- sage/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/__init__.py +83 -0
- sage/middleware/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/api/__init__.py +22 -0
- sage/middleware/api/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/api/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/api/__pycache__/graph_api.cpython-311.opt-2.pyc +0 -0
- sage/middleware/api/__pycache__/graph_api.cpython-311.pyc +0 -0
- sage/middleware/api/__pycache__/kv_api.cpython-311.opt-2.pyc +0 -0
- sage/middleware/api/__pycache__/kv_api.cpython-311.pyc +0 -0
- sage/middleware/api/__pycache__/memory_api.cpython-311.opt-2.pyc +0 -0
- sage/middleware/api/__pycache__/memory_api.cpython-311.pyc +0 -0
- sage/middleware/api/__pycache__/vdb_api.cpython-311.opt-2.pyc +0 -0
- sage/middleware/api/__pycache__/vdb_api.cpython-311.pyc +0 -0
- sage/middleware/api/graph_api.py +74 -0
- sage/middleware/api/kv_api.py +45 -0
- sage/middleware/api/memory_api.py +64 -0
- sage/middleware/api/vdb_api.py +60 -0
- sage/middleware/enterprise/__init__.py +75 -0
- sage/middleware/enterprise/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/enterprise/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/enterprise/sage_db/__init__.py +132 -0
- sage/middleware/enterprise/sage_db/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/enterprise/sage_db/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/enterprise/sage_db/__pycache__/sage_db.cpython-311.opt-2.pyc +0 -0
- sage/middleware/enterprise/sage_db/__pycache__/sage_db.cpython-311.pyc +0 -0
- sage/middleware/enterprise/sage_db/python/__init__.py +7 -0
- sage/middleware/enterprise/sage_db/python/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/enterprise/sage_db/python/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/enterprise/sage_db/python/__pycache__/sage_db.cpython-311.opt-2.pyc +0 -0
- sage/middleware/enterprise/sage_db/python/__pycache__/sage_db.cpython-311.pyc +0 -0
- sage/middleware/enterprise/sage_db/python/sage_db.py +44 -0
- sage/middleware/enterprise/sage_db/sage_db.py +395 -0
- sage/middleware/enterprise/sage_db/tests/__pycache__/test_python.cpython-311.opt-2.pyc +0 -0
- sage/middleware/enterprise/sage_db/tests/__pycache__/test_python.cpython-311.pyc +0 -0
- sage/middleware/enterprise/sage_db/tests/test_python.py +144 -0
- sage/middleware/examples/__pycache__/api_usage_tutorial.cpython-311.opt-2.pyc +0 -0
- sage/middleware/examples/__pycache__/api_usage_tutorial.cpython-311.pyc +0 -0
- sage/middleware/examples/__pycache__/dag_microservices_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/examples/__pycache__/dag_microservices_demo.cpython-311.pyc +0 -0
- sage/middleware/examples/__pycache__/microservices_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/examples/__pycache__/microservices_demo.cpython-311.pyc +0 -0
- sage/middleware/examples/__pycache__/microservices_integration_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/examples/__pycache__/microservices_integration_demo.cpython-311.pyc +0 -0
- sage/middleware/examples/__pycache__/microservices_registration_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/examples/__pycache__/microservices_registration_demo.cpython-311.pyc +0 -0
- sage/middleware/examples/api_usage_tutorial.py +339 -0
- sage/middleware/examples/dag_microservices_demo.py +220 -0
- sage/middleware/examples/microservices_demo.py +0 -0
- sage/middleware/examples/microservices_integration_demo.py +373 -0
- sage/middleware/examples/microservices_registration_demo.py +144 -0
- sage/middleware/py.typed +2 -0
- sage/middleware/services/graph/__init__.py +8 -0
- sage/middleware/services/graph/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/graph/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/graph/__pycache__/graph_index.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/graph/__pycache__/graph_index.cpython-311.pyc +0 -0
- sage/middleware/services/graph/__pycache__/graph_service.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/graph/__pycache__/graph_service.cpython-311.pyc +0 -0
- sage/middleware/services/graph/examples/__pycache__/graph_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/graph/examples/__pycache__/graph_demo.cpython-311.pyc +0 -0
- sage/middleware/services/graph/examples/graph_demo.py +177 -0
- sage/middleware/services/graph/graph_index.py +194 -0
- sage/middleware/services/graph/graph_service.py +541 -0
- sage/middleware/services/graph/search_engine/__init__.py +0 -0
- sage/middleware/services/graph/search_engine/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/graph/search_engine/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/graph/search_engine/__pycache__/base_graph_index.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/graph/search_engine/__pycache__/base_graph_index.cpython-311.pyc +0 -0
- sage/middleware/services/graph/search_engine/base_graph_index.py +0 -0
- sage/middleware/services/kv/__init__.py +8 -0
- sage/middleware/services/kv/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/kv/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/kv/__pycache__/kv_service.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/kv/__pycache__/kv_service.cpython-311.pyc +0 -0
- sage/middleware/services/kv/examples/__pycache__/kv_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/kv/examples/__pycache__/kv_demo.cpython-311.pyc +0 -0
- sage/middleware/services/kv/examples/kv_demo.py +213 -0
- sage/middleware/services/kv/kv_service.py +306 -0
- sage/middleware/services/kv/search_engine/__init__.py +0 -0
- sage/middleware/services/kv/search_engine/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/kv/search_engine/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/kv/search_engine/__pycache__/base_kv_index.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/kv/search_engine/__pycache__/base_kv_index.cpython-311.pyc +0 -0
- sage/middleware/services/kv/search_engine/__pycache__/bm25s_index.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/kv/search_engine/__pycache__/bm25s_index.cpython-311.pyc +0 -0
- sage/middleware/services/kv/search_engine/base_kv_index.py +75 -0
- sage/middleware/services/kv/search_engine/bm25s_index.py +238 -0
- sage/middleware/services/memory/__init__.py +12 -0
- sage/middleware/services/memory/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/memory/__pycache__/memory_service.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/__pycache__/memory_service.cpython-311.pyc +0 -0
- sage/middleware/services/memory/examples/__pycache__/dag_microservices_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/examples/__pycache__/dag_microservices_demo.cpython-311.pyc +0 -0
- sage/middleware/services/memory/examples/__pycache__/memory_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/examples/__pycache__/memory_demo.cpython-311.pyc +0 -0
- sage/middleware/services/memory/examples/dag_microservices_demo.py +220 -0
- sage/middleware/services/memory/examples/memory_demo.py +490 -0
- sage/middleware/services/memory/memory_collection/__pycache__/base_collection.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/memory_collection/__pycache__/base_collection.cpython-311.pyc +0 -0
- sage/middleware/services/memory/memory_collection/__pycache__/graph_collection.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/memory_collection/__pycache__/graph_collection.cpython-311.pyc +0 -0
- sage/middleware/services/memory/memory_collection/__pycache__/kv_collection.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/memory_collection/__pycache__/kv_collection.cpython-311.pyc +0 -0
- sage/middleware/services/memory/memory_collection/__pycache__/vdb_collection.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/memory_collection/__pycache__/vdb_collection.cpython-311.pyc +0 -0
- sage/middleware/services/memory/memory_collection/base_collection.py +0 -0
- sage/middleware/services/memory/memory_collection/graph_collection.py +0 -0
- sage/middleware/services/memory/memory_collection/kv_collection.py +0 -0
- sage/middleware/services/memory/memory_collection/vdb_collection.py +0 -0
- sage/middleware/services/memory/memory_service.py +474 -0
- sage/middleware/services/memory/utils/__init__.py +0 -0
- sage/middleware/services/memory/utils/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/memory/utils/__pycache__/path_utils.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/memory/utils/__pycache__/path_utils.cpython-311.pyc +0 -0
- sage/middleware/services/memory/utils/path_utils.py +0 -0
- sage/middleware/services/vdb/__init__.py +8 -0
- sage/middleware/services/vdb/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/vdb/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/vdb/__pycache__/vdb_service.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/vdb/__pycache__/vdb_service.cpython-311.pyc +0 -0
- sage/middleware/services/vdb/examples/__pycache__/vdb_demo.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/vdb/examples/__pycache__/vdb_demo.cpython-311.pyc +0 -0
- sage/middleware/services/vdb/examples/vdb_demo.py +447 -0
- sage/middleware/services/vdb/search_engine/__init__.py +0 -0
- sage/middleware/services/vdb/search_engine/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/vdb/search_engine/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/services/vdb/search_engine/__pycache__/base_vdb_index.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/vdb/search_engine/__pycache__/base_vdb_index.cpython-311.pyc +0 -0
- sage/middleware/services/vdb/search_engine/__pycache__/faiss_index.cpython-311.opt-2.pyc +0 -0
- sage/middleware/services/vdb/search_engine/__pycache__/faiss_index.cpython-311.pyc +0 -0
- sage/middleware/services/vdb/search_engine/base_vdb_index.py +58 -0
- sage/middleware/services/vdb/search_engine/faiss_index.py +461 -0
- sage/middleware/services/vdb/vdb_service.py +433 -0
- sage/middleware/utils/__init__.py +5 -0
- sage/middleware/utils/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__init__.py +35 -0
- sage/middleware/utils/embedding/__pycache__/__init__.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/__init__.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/_cohere.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/_cohere.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/bedrock.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/bedrock.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/embedding_api.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/embedding_api.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/embedding_model.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/embedding_model.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/hf.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/hf.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/instructor.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/instructor.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/jina.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/jina.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/lollms.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/lollms.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/mockembedder.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/mockembedder.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/nvidia_openai.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/nvidia_openai.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/ollama.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/ollama.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/openai.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/openai.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/siliconcloud.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/siliconcloud.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/zhipu.cpython-311.opt-2.pyc +0 -0
- sage/middleware/utils/embedding/__pycache__/zhipu.cpython-311.pyc +0 -0
- sage/middleware/utils/embedding/_cohere.py +68 -0
- sage/middleware/utils/embedding/bedrock.py +174 -0
- sage/middleware/utils/embedding/embedding_api.py +12 -0
- sage/middleware/utils/embedding/embedding_model.py +150 -0
- sage/middleware/utils/embedding/hf.py +90 -0
- sage/middleware/utils/embedding/instructor.py +10 -0
- sage/middleware/utils/embedding/jina.py +115 -0
- sage/middleware/utils/embedding/lollms.py +100 -0
- sage/middleware/utils/embedding/mockembedder.py +46 -0
- sage/middleware/utils/embedding/nvidia_openai.py +97 -0
- sage/middleware/utils/embedding/ollama.py +97 -0
- sage/middleware/utils/embedding/openai.py +112 -0
- sage/middleware/utils/embedding/siliconcloud.py +133 -0
- sage/middleware/utils/embedding/zhipu.py +85 -0
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
import pipmaster as pm # Pipmaster for dynamic library install
|
|
6
|
+
|
|
7
|
+
# Dependencies should be installed via requirements.txt
|
|
8
|
+
# aioboto3 and tenacity are required for this module
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import aioboto3
|
|
12
|
+
except ImportError:
|
|
13
|
+
raise ImportError(
|
|
14
|
+
"aioboto3 package is required for AWS Bedrock embedding functionality. "
|
|
15
|
+
"Please install it via: pip install aioboto3"
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
from tenacity import (
|
|
20
|
+
retry,
|
|
21
|
+
stop_after_attempt,
|
|
22
|
+
wait_exponential,
|
|
23
|
+
retry_if_exception_type,
|
|
24
|
+
)
|
|
25
|
+
except ImportError:
|
|
26
|
+
raise ImportError(
|
|
27
|
+
"tenacity package is required for AWS Bedrock embedding functionality. "
|
|
28
|
+
"Please install it via: pip install tenacity"
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
import numpy as np
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class BedrockError(Exception):
|
|
36
|
+
"""Generic error for issues related to Amazon Bedrock"""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
async def bedrock_embed(
    text: str,
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> list:
    """Generate an embedding for *text* via AWS Bedrock (asynchronous).

    Args:
        text: Input text to embed.
        model: Bedrock model id; Amazon Titan (v1/v2) and Cohere models
            are supported.
        aws_access_key_id: Optional AWS access key; an existing environment
            variable takes precedence.
        aws_secret_access_key: Optional AWS secret key (same precedence).
        aws_session_token: Optional AWS session token (same precedence).

    Returns:
        list: The embedding vector.

    Raises:
        ValueError: If the model or its provider is not supported.
    """
    # Export credentials only when a value is actually available; the original
    # unconditionally assigned os.environ.get(name, arg), which raises
    # TypeError ("str expected, not NoneType") when both are missing.
    for env_name, value in (
        ("AWS_ACCESS_KEY_ID", aws_access_key_id),
        ("AWS_SECRET_ACCESS_KEY", aws_secret_access_key),
        ("AWS_SESSION_TOKEN", aws_session_token),
    ):
        if env_name not in os.environ and value is not None:
            os.environ[env_name] = value

    session = aioboto3.Session()
    # NOTE(review): the original service name "bedrock-sage.kernels.runtime"
    # appears to be a corrupted global rename of AWS's "bedrock-runtime"
    # service -- confirm against deployment config.
    async with session.client("bedrock-runtime") as bedrock_async_client:
        model_provider = model.split(".")[0]

        if model_provider == "amazon":
            if "v2" in model:
                body = json.dumps({
                    "inputText": text,
                    "embeddingTypes": ["float"],
                })
            elif "v1" in model:
                body = json.dumps({"inputText": text})
            else:
                raise ValueError(f"Model {model} is not supported!")

            response = await bedrock_async_client.invoke_model(
                modelId=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )
            # The response body is a streaming payload with no .json() method;
            # read the raw bytes and decode them explicitly.
            response_body = json.loads(await response["body"].read())
            return response_body["embedding"]

        elif model_provider == "cohere":
            body = json.dumps({
                "texts": [text],
                "input_type": "search_document",
                "truncate": "NONE",
            })

            # invoke_model takes modelId (the original passed model=, which
            # is not an accepted parameter).
            response = await bedrock_async_client.invoke_model(
                modelId=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )
            response_body = json.loads(await response["body"].read())
            return response_body["embeddings"][0]

        else:
            raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
import os
|
|
104
|
+
import json
|
|
105
|
+
import boto3
|
|
106
|
+
|
|
107
|
+
def bedrock_embed_sync(
    text: str,
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> list[float]:
    """Synchronous version: generate an embedding via AWS Bedrock.

    Args:
        text: Input text.
        model: Model id, e.g. "amazon.titan-embed-text-v2:0".
        aws_access_key_id: Optional AWS access key (exported to the
            environment when given; parameters take precedence here).
        aws_secret_access_key: Optional AWS secret key.
        aws_session_token: Optional AWS session token.

    Returns:
        list[float]: The embedding vector.

    Raises:
        ValueError: If the model or its provider is not supported.
    """
    # Export credentials to the environment when supplied.
    if aws_access_key_id:
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
    if aws_secret_access_key:
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
    if aws_session_token:
        os.environ["AWS_SESSION_TOKEN"] = aws_session_token

    # NOTE(review): the original service name "bedrock-sage.kernels.runtime"
    # appears to be a corrupted global rename of AWS's "bedrock-runtime"
    # service -- confirm against deployment config.
    bedrock_client = boto3.client("bedrock-runtime")

    model_provider = model.split(".")[0]

    if model_provider == "amazon":
        if "v2" in model:
            body = json.dumps({
                "inputText": text,
                "embeddingTypes": ["float"],
            })
        elif "v1" in model:
            body = json.dumps({"inputText": text})
        else:
            raise ValueError(f"Model {model} is not supported!")

        response = bedrock_client.invoke_model(
            modelId=model,
            body=body,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response["body"].read())
        return response_body["embedding"]

    elif model_provider == "cohere":
        body = json.dumps({
            "texts": [text],
            "input_type": "search_document",
            "truncate": "NONE",
        })

        response = bedrock_client.invoke_model(
            modelId=model,
            body=body,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response["body"].read())
        return response_body["embeddings"][0]

    else:
        raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from sage.middleware.utils.embedding.embedding_model import EmbeddingModel
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def apply_embedding_model(name: str = "default", **kwargs) -> EmbeddingModel:
    """Factory wrapper around :class:`EmbeddingModel`.

    *name* selects the embedding backend (forwarded as ``method``); all
    other keyword arguments are passed through unchanged. Usage examples
    live in sage/api/model/operator_test.py. Typical parameter sets:

    * ``name="hf"`` -- also pass ``model``;
    * ``name="openai"`` -- for OpenAI-compatible endpoints pass
      ``base_url``, ``api_key`` and ``model``;
    * ``name="jina"`` / ``"siliconcloud"`` / ``"cohere"`` -- pass
      ``api_key`` and ``model``.
    """
    return EmbeddingModel(method=name, **kwargs)
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
4
|
+
|
|
5
|
+
# 添加项目根目录到Python路径
|
|
6
|
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")))
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
|
|
10
|
+
from dotenv import load_dotenv
|
|
11
|
+
|
|
12
|
+
load_dotenv()
|
|
13
|
+
|
|
14
|
+
from sage.middleware.utils.embedding import hf, ollama, siliconcloud, openai, bedrock, zhipu, mockembedder # , instructor
|
|
15
|
+
from sage.middleware.utils.embedding import _cohere, nvidia_openai, lollms, jina
|
|
16
|
+
from transformers import AutoModel, AutoTokenizer
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class EmbeddingModel:
    """Unified front-end over multiple text-embedding backends.

    The ``method`` argument selects the backend (e.g. "openai", "cohere",
    "hf", "mockembedder"); backend-specific options are passed as keyword
    arguments and forwarded to the backend's synchronous embed function.
    """

    def __init__(self, method: str = "openai", **kwargs):
        """Initialize the embedding model.

        :param method: Embedding method name, e.g. "openai", "cohere" or
            "hf". "default" is an alias for the sentence-transformers
            MiniLM model via the "hf" backend.
        :param kwargs: Backend-specific options (e.g. model, api_key,
            base_url, fixed_dim).
        :raises ValueError: If no model name is supplied, the model's
            dimension is unknown, or the method is unsupported.
        """
        self.init_method = method
        self.dim = None
        if method == "default":
            method = "hf"
            kwargs["model"] = "sentence-transformers/all-MiniLM-L6-v2"

        if method == "mockembedder":
            kwargs["model"] = "mockembedder"  # ensure the model key exists
            if "fixed_dim" not in kwargs:
                kwargs["fixed_dim"] = 128  # default dimension

        # Fail with a clear error instead of a bare KeyError when the caller
        # forgot to pass a model name. (The original hf-only check below was
        # unreachable because kwargs["model"] was read first.)
        if "model" not in kwargs:
            raise ValueError(f"method '{method}' requires a 'model' argument")

        self.set_dim(kwargs["model"])
        self.method = method

        self.kwargs = kwargs
        if method == "hf":
            model_name = kwargs["model"]
            self.kwargs["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
            self.kwargs["embed_model"] = AutoModel.from_pretrained(model_name, trust_remote_code=True)
            self.kwargs.pop("model")
        elif method == "mockembedder":
            # The mock backend needs no network access or model download.
            self.kwargs["embed_model"] = mockembedder.MockTextEmbedder(
                model_name="mock-model",
                fixed_dim=kwargs.get("fixed_dim", 128)
            )
        self.embed_fn = self._get_embed_function(method)

    def set_dim(self, model_name):
        """Look up and store the embedding dimension for *model_name*.

        :param model_name: Backend model identifier.
        :raises ValueError: If the model's dimension is not known.
        """
        dimension_mapping = {
            "mistral_embed": 1024,
            "embed-multilingual-v3.0": 1024,
            "embed-english-v3.0": 1024,
            "embed-english-light-v3.0": 384,
            "embed-multilingual-light-v3.0": 384,
            "embed-english-v2.0": 4096,
            "embed-english-light-v2.0": 1024,
            "embed-multilingual-v2.0": 768,
            "jina-embeddings-v3": 1024,
            "BAAI/bge-m3": 1024,
            "sentence-transformers/all-MiniLM-L6-v2": 384,
            "mockembedder": 128
        }
        if model_name in dimension_mapping:
            self.dim = dimension_mapping[model_name]
        else:
            # Original message was garbled ("<UNK> embedding <UNK>...").
            raise ValueError(f"Unknown embedding model: {model_name}")

    def get_dim(self):
        """Return the embedding dimension recorded for the active model."""
        return self.dim

    def _get_embed_function(self, method: str):
        """Return the synchronous embed function registered for *method*.

        :raises ValueError: If the method is not supported.
        """
        mapping = {
            "openai": openai.openai_embed_sync,
            "zhipu": zhipu.zhipu_embedding_sync,
            "bedrock": bedrock.bedrock_embed_sync,
            "hf": hf.hf_embed_sync,
            "jina": jina.jina_embed_sync,
            "lollms": lollms.lollms_embed_sync,
            "nvidia_openai": nvidia_openai.nvidia_openai_embed_sync,
            "ollama": ollama.ollama_embed_sync,
            "siliconcloud": siliconcloud.siliconcloud_embedding_sync,
            "cohere": _cohere.cohere_embed_sync,
            "mockembedder": lambda text, **kwargs: kwargs["embed_model"].encode(text).tolist(),
        }
        if method not in mapping:
            raise ValueError(f"不支持的 embedding 方法:{method}")

        return mapping[method]

    def _embed(self, text: str) -> list[float]:
        """Run the configured embed function on *text*.

        :param text: Text to embed.
        :return: Embedding vector as a list of floats.
        """
        return self.embed_fn(text, **self.kwargs)

    def embed(self, text: str) -> list[float]:
        """Embed *text*; alias of :meth:`encode`."""
        return self._embed(text)

    def encode(self, text: str) -> list[float]:
        """Embed *text*; alias of :meth:`embed`."""
        return self._embed(text)

    @property
    def method_name(self) -> str:
        """Name of the embedding method this instance was created with."""
        return self.init_method
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def apply_embedding_model(name: str = "default", **kwargs) -> EmbeddingModel:
    """Build an :class:`EmbeddingModel` for the backend named *name*.

    Extra keyword arguments are forwarded untouched to the constructor.
    See sage/api/model/operator_test.py for usage. Required parameters by
    backend:

    * ``"hf"``: ``model``;
    * ``"openai"`` (and OpenAI-compatible endpoints): ``base_url``,
      ``api_key``, ``model``;
    * ``"jina"`` / ``"siliconcloud"`` / ``"cohere"``: ``api_key``, ``model``.
    """
    return EmbeddingModel(method=name, **kwargs)
|
|
139
|
+
|
|
140
|
+
def main():
    """Smoke-test the HF backend: embed ten short strings, printing each
    vector and its per-call latency."""
    embedding_model = EmbeddingModel(method="hf", model="sentence-transformers/all-MiniLM-L6-v2")
    for i in range(10):
        # perf_counter is monotonic and higher-resolution than time.time(),
        # which can jump with wall-clock adjustments.
        start = time.perf_counter()
        v = embedding_model.embed(f"{i} times ")
        print(v)
        end = time.perf_counter()
        print(f"embedding time :{end-start}")

if __name__ =="__main__":
    main()
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import copy
|
|
3
|
+
import os
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
import asyncio
|
|
6
|
+
import pipmaster as pm # Pipmaster for dynamic library install
|
|
7
|
+
|
|
8
|
+
# Dependencies should be installed via requirements.txt
|
|
9
|
+
# transformers, torch, tenacity, and numpy are required for this module
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
|
|
13
|
+
except ImportError:
|
|
14
|
+
raise ImportError(
|
|
15
|
+
"transformers package is required for HuggingFace embedding functionality. "
|
|
16
|
+
"Please install it via: pip install transformers"
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
import torch
|
|
21
|
+
except ImportError:
|
|
22
|
+
raise ImportError(
|
|
23
|
+
"torch package is required for HuggingFace embedding functionality. "
|
|
24
|
+
"Please install it via: pip install torch"
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
import tenacity
|
|
29
|
+
except ImportError:
|
|
30
|
+
raise ImportError(
|
|
31
|
+
"tenacity package is required for HuggingFace embedding functionality. "
|
|
32
|
+
"Please install it via: pip install tenacity"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
import numpy
|
|
37
|
+
except ImportError:
|
|
38
|
+
raise ImportError(
|
|
39
|
+
"numpy package is required for HuggingFace embedding functionality. "
|
|
40
|
+
"Please install it via: pip install numpy"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
44
|
+
|
|
45
|
+
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
    """Load a HuggingFace tokenizer/model pair for *model_name*.

    Only the most recently requested model is cached (maxsize=1), so
    repeated calls with the same name are free while switching names
    reloads.

    Returns:
        tuple: (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    # Some tokenizers ship without a pad token; fall back to EOS so that
    # padded batches still work.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
import torch
|
|
60
|
+
|
|
61
|
+
def hf_embed_sync(text: str, tokenizer, embed_model) -> list[float]:
    """Synchronously embed *text* with a HuggingFace model.

    The embedding is the mean of the model's last hidden state over the
    sequence dimension.

    Args:
        text (str): Input text.
        tokenizer: A loaded tokenizer.
        embed_model: A loaded PyTorch embedding model.

    Returns:
        list[float]: The embedding vector.
    """
    # Run the forward pass on whatever device the model's parameters live on.
    target_device = next(embed_model.parameters()).device
    batch = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    ).to(target_device)

    with torch.no_grad():
        output = embed_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    pooled = output.last_hidden_state.mean(dim=1).detach()
    # bfloat16 tensors cannot be converted to Python floats directly;
    # upcast to float32 first.
    if pooled.dtype == torch.bfloat16:
        pooled = pooled.to(torch.float32)
    return pooled.cpu()[0].tolist()
|
|
89
|
+
|
|
90
|
+
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from InstructorEmbedding import INSTRUCTOR
|
|
3
|
+
from sentence_transformers import SentenceTransformer
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
async def instructor_embed(
    texts: list[str], model: str = "hkunlp/instructor-large"
) -> np.ndarray:
    """Embed a batch of *texts* with a SentenceTransformer model.

    Args:
        texts: Input strings to embed.
        model: Model name or path to load.

    Returns:
        np.ndarray: One embedding row per input text.
    """
    # NOTE(review): the model is re-loaded on every call; consider caching
    # the SentenceTransformer instance if this is used on a hot path.
    _model = SentenceTransformer(model)
    return _model.encode(texts)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
import pipmaster as pm # Pipmaster for dynamic library install
|
|
4
|
+
|
|
5
|
+
# Dependencies should be installed via requirements.txt
|
|
6
|
+
# tenacity is required for this module
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import tenacity
|
|
10
|
+
except ImportError:
|
|
11
|
+
raise ImportError(
|
|
12
|
+
"tenacity package is required for Jina embedding functionality. "
|
|
13
|
+
"Please install it via: pip install tenacity"
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
import aiohttp
|
|
18
|
+
except ImportError:
|
|
19
|
+
raise ImportError(
|
|
20
|
+
"aiohttp package is required for Jina embedding functionality. "
|
|
21
|
+
"Please install it via: pip install aiohttp"
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
async def fetch_data(url, headers, data):
    """POST *data* as JSON to *url* and return the response's "data" list.

    Args:
        url: Endpoint URL.
        headers: HTTP headers to send.
        data: JSON-serializable request payload.

    Returns:
        list: The "data" field of the JSON response, or [] when absent.
    """
    async with aiohttp.ClientSession() as http:
        async with http.post(url, headers=headers, json=data) as resp:
            payload = await resp.json()
    return payload.get("data", [])
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
async def jina_embed(
    text: str,
    dimensions: int = 1024,
    late_chunking: bool = False,
    base_url: str = None,
    api_key: str = None,
    model: str = "jina-embeddings-v3",
) -> list[float]:
    """Fetch an embedding for ``text`` from the Jina AI embeddings API.

    Args:
        text: Text to embed.
        dimensions: Requested embedding dimensionality.
        late_chunking: Whether to enable Jina's late-chunking mode.
        base_url: Override for the API endpoint (defaults to the public API).
        api_key: Jina API key; when given it is stored in ``JINA_API_KEY``.
        model: Jina embedding model name.

    Returns:
        list[float]: The embedding vector of the first result.

    Raises:
        KeyError: If no API key is supplied and ``JINA_API_KEY`` is unset.
    """
    if api_key:
        os.environ["JINA_API_KEY"] = api_key
    url = base_url or "https://api.jina.ai/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
    }
    data = {
        "model": model,
        "normalized": True,
        "embedding_type": "float",
        # Send the dimension as an integer, matching jina_embed_sync's payload.
        "dimensions": dimensions,
        "late_chunking": late_chunking,
        "input": text,
    }
    # Debug print of the raw API response removed.
    data_list = await fetch_data(url, headers, data)
    return data_list[0]["embedding"]
|
|
61
|
+
|
|
62
|
+
import os
|
|
63
|
+
import requests
|
|
64
|
+
|
|
65
|
+
def jina_embed_sync(
    text: str,
    dimensions: int = 1024,
    late_chunking: bool = False,
    base_url: str = None,
    api_key: str = None,
    model: str = "jina-embeddings-v3",
    timeout: float = None,
) -> list[float]:
    """Synchronous version: call the Jina AI embedding API for a vector.

    Args:
        text: Text to embed.
        dimensions: Requested embedding dimensionality.
        late_chunking: Whether to enable late chunking.
        base_url: Custom API endpoint (optional).
        api_key: Jina API key; when given it is stored in ``JINA_API_KEY``.
        model: Model name to use.
        timeout: Optional request timeout in seconds; ``None`` waits
            indefinitely, preserving the previous behavior.

    Returns:
        list[float]: The embedding vector.

    Raises:
        RuntimeError: If the HTTP request or response parsing fails.
    """
    if api_key:
        os.environ["JINA_API_KEY"] = api_key

    url = base_url or "https://api.jina.ai/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
    }
    payload = {
        "model": model,
        "normalized": True,
        "embedding_type": "float",
        "dimensions": dimensions,
        "late_chunking": late_chunking,
        "input": text,
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        return data["data"][0]["embedding"]
    except Exception as e:
        # Chain the original exception so the root cause stays diagnosable.
        raise RuntimeError(f"Jina API call failed: {str(e)}") from e
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
|
|
3
|
+
if sys.version_info < (3, 9):
|
|
4
|
+
from typing import AsyncIterator
|
|
5
|
+
else:
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
import pipmaster as pm # Pipmaster for dynamic library install
|
|
8
|
+
|
|
9
|
+
# Dependencies should be installed via requirements.txt
|
|
10
|
+
# aiohttp and tenacity are required for this module
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import aiohttp
|
|
14
|
+
except ImportError:
|
|
15
|
+
raise ImportError(
|
|
16
|
+
"aiohttp package is required for Lollms embedding functionality. "
|
|
17
|
+
"Please install it via: pip install aiohttp"
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import tenacity
|
|
22
|
+
except ImportError:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
"tenacity package is required for Lollms embedding functionality. "
|
|
25
|
+
"Please install it via: pip install tenacity"
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
async def lollms_embed(
    text: str,
    embed_model=None,
    base_url="http://localhost:9600",
    **kwargs,
) -> list:
    """Generate an embedding for one text via a lollms server.

    Args:
        text: The string to embed.
        embed_model: Model name (unused; lollms applies its configured vectorizer).
        base_url: Base URL of the lollms server.
        **kwargs: Extra options; ``api_key`` is sent as the Authorization header.

    Returns:
        list[float]: The embedding vector from the server response.
    """
    api_key = kwargs.pop("api_key", None)
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = api_key

    payload = {"text": text}
    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(
            f"{base_url}/lollms_embed",
            json=payload,
        ) as response:
            body = await response.json()
            return body["vector"]
|
|
64
|
+
|
|
65
|
+
import requests
|
|
66
|
+
|
|
67
|
+
def lollms_embed_sync(
    text: str,
    embed_model=None,
    base_url="http://localhost:9600",
    **kwargs,
) -> list[float]:
    """Synchronous version: generate an embedding via a local lollms server.

    Args:
        text: Input text.
        embed_model: Model name (not used directly).
        base_url: Address of the lollms server.
        **kwargs: Optional arguments, e.g. ``api_key`` (Authorization header)
            or ``timeout`` (request timeout in seconds; ``None`` keeps the
            previous wait-forever behavior).

    Returns:
        list[float]: The resulting embedding vector.

    Raises:
        RuntimeError: If the HTTP request or response parsing fails.
    """
    api_key = kwargs.pop("api_key", None)
    timeout = kwargs.pop("timeout", None)
    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers["Authorization"] = api_key

    request_data = {"text": text}

    try:
        response = requests.post(
            f"{base_url}/lollms_embed",
            json=request_data,
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()
        result = response.json()
        return result["vector"]
    except Exception as e:
        # Chain the cause so connection/HTTP errors remain visible upstream.
        raise RuntimeError(f"lollms embedding request failed: {str(e)}") from e
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import hashlib
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
class MockTextEmbedder:
    """Deterministic fake text embedder for tests.

    Emits fixed-dimension pseudo-embeddings: the same (model_name, text) pair
    always yields the same vector, without loading a real model.
    """

    def __init__(self, model_name: str = 'mock-model', fixed_dim: int = 128):
        """Mock model emitting fixed-dimension tensors (same text -> same output)."""
        self.fixed_dim = fixed_dim
        # Derive the seed from the model name so different instances with the
        # same name behave identically.
        self.seed = int(hashlib.sha256(model_name.encode()).hexdigest()[:8], 16)
        # method_name attribute kept for MemoryManager compatibility.
        self.method_name = "mockembedder"

    def encode(self, text: str, max_length: int = 512, stride: Optional[int] = None) -> torch.Tensor:
        """Return a deterministic pseudo-embedding for ``text``.

        Blank/whitespace-only input yields an all-zero vector.
        """
        if not text.strip():
            return torch.zeros(self.fixed_dim)

        # Deterministic per-text seed. A local Generator is used instead of
        # torch.manual_seed so encode() no longer clobbers the global RNG state.
        text_seed = self.seed + int(hashlib.sha256(text.encode()).hexdigest()[:8], 16)
        gen = torch.Generator()
        gen.manual_seed(text_seed)

        # Mirror the real embedder's path selection: short inputs take the
        # single-pass path, long ones the sliding-window path.
        if stride is None or len(text.split()) <= max_length:
            return self._embed_single(gen)
        return self._embed_with_sliding_window(gen)

    def _embed_single(self, gen: torch.Generator) -> torch.Tensor:
        """Simulate embedding a single (short) text."""
        embedding = torch.randn(384, generator=gen)  # mimic the backbone's hidden size
        return self._adjust_dimension(embedding)

    def _embed_with_sliding_window(self, gen: torch.Generator) -> torch.Tensor:
        """Simulate sliding-window embedding of a long text (3 mock windows)."""
        windows = torch.stack([torch.randn(384, generator=gen) for _ in range(3)])
        return self._adjust_dimension(windows.mean(dim=0))

    def _adjust_dimension(self, embedding: torch.Tensor) -> torch.Tensor:
        """Truncate or zero-pad the vector to ``self.fixed_dim``."""
        if embedding.size(0) > self.fixed_dim:
            return embedding[:self.fixed_dim]
        if embedding.size(0) < self.fixed_dim:
            padding = torch.zeros(self.fixed_dim - embedding.size(0))
            return torch.cat((embedding, padding))
        return embedding
|