smartrouter 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smart_router/__init__.py +1 -0
- smart_router/assets/coffee_qr.png +0 -0
- smart_router/classifier/__init__.py +6 -0
- smart_router/classifier/difficulty_classifier.py +108 -0
- smart_router/classifier/embedding_matcher.py +137 -0
- smart_router/classifier/task_classifier.py +372 -0
- smart_router/classifier/types.py +23 -0
- smart_router/cli.py +823 -0
- smart_router/config/__init__.py +14 -0
- smart_router/config/loader.py +76 -0
- smart_router/config/schema.py +278 -0
- smart_router/config/v3_loader.py +21 -0
- smart_router/config/v3_schema.py +33 -0
- smart_router/gateway/__init__.py +16 -0
- smart_router/gateway/daemon.py +289 -0
- smart_router/gateway/server.py +205 -0
- smart_router/gateway/server_main.py +30 -0
- smart_router/misc/__init__.py +13 -0
- smart_router/misc/coffee_qr.py +200 -0
- smart_router/router/__init__.py +8 -0
- smart_router/router/plugin.py +194 -0
- smart_router/router/plugin_v3_adapter.py +128 -0
- smart_router/selector/__init__.py +13 -0
- smart_router/selector/model_selector.py +200 -0
- smart_router/selector/strategies.py +20 -0
- smart_router/selector/v3_selector.py +384 -0
- smart_router/templates/models.yaml +373 -0
- smart_router/templates/providers.yaml +63 -0
- smart_router/templates/routing.yaml +145 -0
- smart_router/utils/__init__.py +0 -0
- smart_router/utils/markers.py +45 -0
- smart_router/utils/token_counter.py +74 -0
- smartrouter-1.0.9.dist-info/METADATA +452 -0
- smartrouter-1.0.9.dist-info/RECORD +37 -0
- smartrouter-1.0.9.dist-info/WHEEL +4 -0
- smartrouter-1.0.9.dist-info/entry_points.txt +3 -0
- smartrouter-1.0.9.dist-info/licenses/LICENSE +674 -0
smart_router/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "1.0.9"
|
|
Binary file
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""难度评估器 - 独立评估难度"""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from typing import Dict, List, Optional, Literal
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class DifficultyResult:
    """Outcome of a difficulty assessment."""

    # Difficulty bucket assigned to the input.
    difficulty: Literal["easy", "medium", "hard"]
    # Confidence score attached to the assessment.
    confidence: float
    # Where the verdict came from.
    source: str  # "rule" | "default"
    # Description (or raw condition) of the rule that matched, if any.
    matched_rule: Optional[str] = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DifficultyClassifier:
    """Rule-based difficulty classifier.

    Rules are plain dicts. The keys read by this class are:

    - ``condition``: the condition string (see ``_match_condition``).
    - ``difficulty``: the difficulty to report when the rule matches (required).
    - ``priority``: sort key, lower values are evaluated first (default 1).
    - ``applies_to``: optional collection of task types the rule is limited to.
    - ``description``: optional label reported as ``matched_rule``.
    """

    def __init__(self, rules: List[Dict]):
        """
        Args:
            rules: Difficulty rules. They are sorted by ascending ``priority``;
                the sort is stable, so rules sharing a priority keep their
                input order.
        """
        self.rules = sorted(rules, key=lambda rule: rule.get("priority", 1))
        self.default_difficulty = "medium"

    def classify(
        self,
        text: str,
        task_type: Optional[str] = None
    ) -> DifficultyResult:
        """Assess the difficulty of ``text``.

        Args:
            text: User input text.
            task_type: Task type, used to skip rules whose ``applies_to``
                does not include it.

        Returns:
            The result of the first matching rule (confidence 0.8, source
            ``"rule"``). Otherwise the default difficulty with source
            ``"default"``: confidence 0.0 for empty text, 0.5 when no rule
            matched.

        Raises:
            KeyError: If a matching rule has no ``difficulty`` key.
        """
        if not text:
            return DifficultyResult(
                difficulty=self.default_difficulty,
                confidence=0.0,
                source="default"
            )

        text_lower = text.lower()

        # Rules are already ordered by priority; the first hit wins.
        for rule in self.rules:
            # Skip rules restricted to other task types.
            applies_to = rule.get("applies_to")
            if applies_to and task_type not in applies_to:
                continue

            condition = rule.get("condition", "")
            if self._match_condition(text_lower, condition):
                return DifficultyResult(
                    difficulty=rule["difficulty"],
                    confidence=0.8,  # fixed confidence for rule hits
                    source="rule",
                    matched_rule=rule.get("description", condition)
                )

        # No rule matched: fall back to the default difficulty.
        return DifficultyResult(
            difficulty=self.default_difficulty,
            confidence=0.5,
            source="default"
        )

    def _match_condition(self, text: str, condition: str) -> bool:
        """Evaluate one rule condition against ``text``.

        Supported forms (the condition is lower-cased before matching, so
        ``text`` is expected to be lower-cased by the caller):

        - ``length < N`` / ``length > N``: compare ``len(text)`` with ``N``.
        - ``keyword:a|b|c`` / ``contains:a|b|c``: true if any alternative is a
          substring of ``text``. Alternatives may contain spaces or
          punctuation (e.g. ``step by step``, ``c++``); surrounding
          whitespace is stripped and empty alternatives are ignored.
        - anything else: plain substring test of the whole condition.

        NOTE: an empty condition is a substring of every text, so a rule
        without a condition matches everything.
        """
        condition = condition.lower()

        # Length comparisons.
        if match := re.match(r'length\s*<\s*(\d+)', condition):
            return len(text) < int(match.group(1))
        if match := re.match(r'length\s*>\s*(\d+)', condition):
            return len(text) > int(match.group(1))

        # Keyword lists. Split on the prefix instead of using a regex character
        # class: a class like [\w|]+ stops at the first space or punctuation
        # mark and silently drops every alternative after it.
        for prefix in ("keyword:", "contains:"):
            if condition.startswith(prefix):
                keywords = (kw.strip() for kw in condition[len(prefix):].split("|"))
                # Skip empty alternatives: "" is contained in every string.
                return any(kw and kw in text for kw in keywords)

        # Default: substring test.
        return condition in text
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""Simple Embedding Matcher - 基于词袋的相似度匹配
|
|
2
|
+
|
|
3
|
+
无需外部依赖,使用简单的词频向量和余弦相似度。
|
|
4
|
+
支持中英文混合文本。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import re
|
|
8
|
+
import math
|
|
9
|
+
from typing import Dict, List, Tuple, Optional
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SimpleEmbeddingMatcher:
    """Dependency-free bag-of-words text matcher.

    Texts are compared through term-frequency vectors (cosine similarity)
    blended with Jaccard similarity over their token sets. Intended for
    small example sets.
    """

    def __init__(self, threshold: float = 0.3):
        """
        Args:
            threshold: Minimum combined score (inclusive) for a category to
                count as a match.
        """
        self.threshold = threshold

    def _jaccard_similarity(self, set1: set, set2: set) -> float:
        """Return |set1 & set2| / |set1 | set2|, or 0.0 if either set is empty."""
        if not (set1 and set2):
            return 0.0
        shared = len(set1 & set2)
        total = len(set1 | set2)
        return shared / total if total != 0 else 0.0

    def tokenize(self, text: str) -> List[str]:
        """Split text into tokens for mixed English/Chinese input.

        The text is lower-cased and stripped first. The result lists every
        run of ``a-z`` letters, followed by every single character in the
        ``\\u4e00-\\u9fff`` range. Everything else (digits, punctuation,
        whitespace) is dropped.
        """
        normalized = text.lower().strip()
        if not normalized:
            return []

        latin_words = re.findall(r'[a-z]+', normalized)
        han_chars = re.findall(r'[\u4e00-\u9fff]', normalized)
        return latin_words + han_chars

    def compute_tf(self, tokens: List[str]) -> Dict[str, float]:
        """Return the term-frequency vector: token -> count / len(tokens)."""
        if not tokens:
            return {}

        counts = {}
        for tok in tokens:
            counts[tok] = counts.get(tok, 0) + 1

        # Normalise by the total number of tokens.
        n_tokens = len(tokens)
        return {tok: count / n_tokens for tok, count in counts.items()}

    def cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
        """Return the cosine similarity of two sparse vectors (0.0 if either is empty or zero)."""
        if not vec1 or not vec2:
            return 0.0

        # Dot product over the keys both vectors share.
        dot_product = 0.0
        for term, weight in vec1.items():
            if term in vec2:
                dot_product += weight * vec2[term]

        # Euclidean norms.
        norm1 = math.sqrt(sum(v ** 2 for v in vec1.values()))
        norm2 = math.sqrt(sum(v ** 2 for v in vec2.values()))
        if norm1 == 0 or norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)

    def find_best_match(
        self,
        text: str,
        examples_map: Dict[str, List[str]]
    ) -> Tuple[Optional[str], float]:
        """Find the example category most similar to ``text``.

        Each example is scored as ``0.6 * cosine + 0.4 * jaccard``. A category
        scores as its best example, and the highest-scoring category wins.

        Args:
            text: Input text.
            examples_map: {task_type: [example1, example2, ...]}

        Returns:
            ``(task_type, score)`` when the best score reaches the threshold.
            ``(None, best_score)`` when it does not.
            ``(None, 0.0)`` when ``text`` yields no tokens.
        """
        query_tokens = self.tokenize(text)
        query_set = set(query_tokens)
        query_vec = self.compute_tf(query_tokens)
        if not query_vec:
            return None, 0.0

        best_type = None
        best_score = 0.0

        for task_type, examples in examples_map.items():
            # A category is represented by its single best example.
            category_score = 0.0
            for example in examples:
                tokens = self.tokenize(example)
                vec = self.compute_tf(tokens)
                if not vec:
                    continue

                cos_sim = self.cosine_similarity(query_vec, vec)
                jac_sim = self._jaccard_similarity(query_set, set(tokens))
                # Weighted blend: cosine 0.6, Jaccard 0.4.
                combined = cos_sim * 0.6 + jac_sim * 0.4
                if combined > category_score:
                    category_score = combined

            if category_score > best_score:
                best_score = category_score
                best_type = task_type

        matched = best_type if best_score >= self.threshold else None
        return matched, best_score
|
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
"""Task Classifier - 任务分类器(重构版)
|
|
2
|
+
|
|
3
|
+
支持两种分类方式:
|
|
4
|
+
1. L1 Keywords: 基于关键词匹配(快速、精确)
|
|
5
|
+
2. L2 Embedding: 基于示例相似度匹配(模糊、泛化)
|
|
6
|
+
|
|
7
|
+
分类流程:
|
|
8
|
+
1. 先尝试 L1 keywords 匹配
|
|
9
|
+
2. 如果未命中,尝试 L2 embedding 相似度匹配
|
|
10
|
+
3. 如果仍未命中,返回默认类型
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import re
|
|
14
|
+
from typing import Dict, List, Optional
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
from .types import ClassificationResult, get_default_classification
|
|
18
|
+
from .difficulty_classifier import DifficultyClassifier, DifficultyResult
|
|
19
|
+
from .embedding_matcher import SimpleEmbeddingMatcher
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class TaskTypeResult:
    """Outcome of a task-type classification."""

    # Name of the task type that was selected.
    task_type: str
    # Confidence score attached to the selection.
    confidence: float
    # Which stage produced the result.
    source: str  # "keyword" | "embedding" | "default"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TaskTypeClassifier:
    """Task-type classifier using keyword matching, then example similarity.

    Classification order:
    1. L1: keyword matching against each task type's ``keywords``.
    2. L2: similarity matching against each task type's ``examples``.
    3. Fallback to the default type (``"chat"``).
    """

    def __init__(self, task_types: Dict[str, Dict]):
        """
        Args:
            task_types: {
                task_type: {
                    keywords: [...],   # used for L1 keyword matching
                    examples: [...],   # used for L2 similarity matching
                    description: ...
                }
            }
        """
        self.task_types = task_types
        self.default_type = "chat"

        # Only task types that actually provide examples take part in L2.
        self.examples_map = {
            task_type: config.get("examples", [])
            for task_type, config in task_types.items()
            if config.get("examples")
        }

        # Threshold 0.28 was chosen by the original author for short Chinese text.
        self._embedding_matcher = SimpleEmbeddingMatcher(threshold=0.28)

    @staticmethod
    def _extract_user_text(messages: List[Dict]) -> str:
        """Join the content of all user messages into one lower-cased string.

        Messages whose role is not ``"user"`` are ignored. User messages with
        missing, ``None`` or empty content are skipped instead of raising, so
        a content-less message cannot break classification.
        """
        parts = []
        for msg in messages:
            if msg.get("role") != "user":
                continue
            content = msg.get("content")
            if content:
                parts.append(content)
        return " ".join(parts).strip().lower()

    def classify(self, messages: List[Dict]) -> TaskTypeResult:
        """Classify the task type of a conversation.

        Args:
            messages: Message dicts; ``role`` and ``content`` are read.

        Returns:
            The keyword result if any keyword matched, else the embedding
            result if one reached the threshold, else the default type with
            confidence 0.0 and source ``"default"``.
        """
        user_content = self._extract_user_text(messages)

        if not user_content:
            return TaskTypeResult(
                task_type=self.default_type,
                confidence=0.0,
                source="default"
            )

        # L1: keyword matching.
        keyword_result = self._classify_by_keywords(user_content)
        if keyword_result is not None:
            return keyword_result

        # L2: example similarity matching.
        embedding_result = self._classify_by_embedding(user_content)
        if embedding_result is not None:
            return embedding_result

        # Nothing matched: fall back to the default type.
        return TaskTypeResult(
            task_type=self.default_type,
            confidence=0.0,
            source="default"
        )

    def _classify_by_keywords(self, text: str) -> Optional[TaskTypeResult]:
        """L1: return the task type with the highest keyword score, or None."""
        best_match = None
        best_score = 0.0

        for task_type, config in self.task_types.items():
            keywords = config.get("keywords", [])
            score = self._calculate_keyword_score(text, keywords)

            # Strict comparison: on a tie the earlier task type is kept.
            if score > best_score:
                best_score = score
                best_match = task_type

        if best_match and best_score > 0:
            return TaskTypeResult(
                task_type=best_match,
                confidence=min(best_score, 1.0),  # score may exceed 1.0 due to the bonus
                source="keyword"
            )

        return None

    def _classify_by_embedding(self, text: str) -> Optional[TaskTypeResult]:
        """L2: return the best example-similarity match, or None."""
        if not self.examples_map:
            return None

        task_type, score = self._embedding_matcher.find_best_match(text, self.examples_map)

        if task_type is not None:
            return TaskTypeResult(
                task_type=task_type,
                confidence=min(score, 1.0),
                source="embedding"
            )

        return None

    def _calculate_keyword_score(self, text: str, keywords: List[str]) -> float:
        """Score how well ``text`` matches a keyword list.

        Each keyword is tried as a case-insensitive regular expression. If it
        is not a valid pattern, it is matched as a lower-cased substring.

        Returns:
            0.0 when nothing matches. Otherwise the fraction of keywords
            matched, plus a bonus of 0.15 for every matched keyword beyond
            the first.
        """
        if not keywords:
            return 0.0

        matched = []
        for keyword in keywords:
            try:
                if re.search(keyword, text, re.IGNORECASE):
                    matched.append(keyword)
            except re.error:
                # Not a valid regex (e.g. "c++"): plain substring match.
                if keyword.lower() in text:
                    matched.append(keyword)

        if not matched:
            return 0.0

        # Base score: share of keywords that matched.
        base_score = len(matched) / len(keywords)
        # Multi-keyword bonus: +0.15 per additional matched keyword.
        bonus = 0.15 * (len(matched) - 1)

        return base_score + bonus
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# Default difficulty-assessment rules.
# Priority: keyword (1) > length (2), so that semantic keywords take
# precedence over length-based checks.
DEFAULT_DIFFICULTY_RULES = [
    # Keywords signalling a complex request.
    {
        "condition": "keyword:复杂|详细|深入|架构|设计模式|优化|重构|性能|并发|分布式",
        "difficulty": "hard",
        "description": "复杂关键词",
        "priority": 1,
    },
    # Keywords signalling an in-depth / step-by-step request.
    {
        "condition": "keyword:step by step|一步步|详细步骤|完整实现|全面分析",
        "difficulty": "hard",
        "description": "深度分析关键词",
        "priority": 1,
    },
    # Keywords signalling a simple request.
    {
        "condition": "keyword:简单|easy|快速|简述|简短|总结一下",
        "difficulty": "easy",
        "description": "简单关键词",
        "priority": 1,
    },
    # Long input text.
    {
        "condition": "length > 500",
        "difficulty": "hard",
        "description": "长文本",
        "priority": 2,
    },
    # Very short input text.
    {
        "condition": "length < 20",
        "difficulty": "easy",
        "description": "极短文本",
        "priority": 2,
    },
]
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class TaskClassifier:
    """Unified task classifier: task type plus dynamic difficulty.

    Task types are resolved by a ``TaskTypeClassifier`` (keywords first, then
    example similarity). Difficulty comes from a ``DifficultyClassifier``
    built on ``DEFAULT_DIFFICULTY_RULES``. When no ``task_configs`` are
    given, the legacy ``rules`` list is converted into keyword lists for
    backward compatibility.
    """

    # Difficulty tiers from lowest to highest, used when shifting a tier.
    DIFFICULTY_ORDER = ["easy", "medium", "hard", "expert"]

    def __init__(
        self,
        rules: List[Dict],
        embedding_config: Dict,
        task_configs: Optional[Dict[str, Dict]] = None
    ):
        """
        Args:
            rules: Legacy classification rules; each may carry ``task_type``
                and ``pattern``. Only consulted when ``task_configs`` is empty.
            embedding_config: Embedding configuration (stored as-is).
            task_configs: {
                task_type: {
                    keywords: [...],
                    examples: [...]
                }
            }
        """
        self.rules = rules
        self.embedding_config = embedding_config
        self.default_type = "chat"
        self.default_difficulty = "medium"

        if task_configs:
            # New-style configuration: copy the fields the type classifier reads.
            task_types = {
                name: {
                    "keywords": cfg.get("keywords", []),
                    "examples": cfg.get("examples", []),
                    "description": cfg.get("description", ""),
                }
                for name, cfg in task_configs.items()
            }
        else:
            # Backward compatibility: every legacy rule pattern becomes a keyword.
            task_types = {}
            for rule in rules:
                name = rule.get("task_type", "")
                if not name:
                    continue
                entry = task_types.setdefault(
                    name, {"keywords": [], "examples": [], "description": ""}
                )
                pattern = rule.get("pattern", "")
                if pattern:
                    entry["keywords"].append(pattern)

        self._type_classifier = TaskTypeClassifier(task_types)

        # Dynamic difficulty assessment always uses the built-in default rules.
        self._difficulty_classifier = DifficultyClassifier(DEFAULT_DIFFICULTY_RULES)

    def classify(self, messages: List[Dict]) -> ClassificationResult:
        """Classify a conversation.

        Steps:
        1. Concatenate the non-empty user messages (lower-cased).
        2. Classify the task type (keywords, then example similarity).
        3. Assess difficulty from the concatenated text.
        4. Raise the difficulty one tier when there are more than 3
           non-empty user messages.

        Args:
            messages: Message dicts; ``role`` and ``content`` are read.

        Returns:
            ClassificationResult. The default classification is returned when
            there is no user text at all.
        """
        fragments = []
        for msg in messages:
            if msg.get("role") != "user":
                continue
            content = msg.get("content", "")
            if content:
                fragments.append(content + " ")
        user_message_count = len(fragments)
        user_content = "".join(fragments).strip().lower()

        if not user_content:
            return get_default_classification()

        # Task type (L1 keywords / L2 embedding).
        type_result = self._type_classifier.classify(messages)
        task_type = type_result.task_type

        # Dynamic difficulty assessment.
        difficulty_result = self._difficulty_classifier.classify(user_content, task_type=task_type)
        difficulty = difficulty_result.difficulty

        # Long multi-turn conversations are treated as one tier harder.
        if user_message_count > 3:
            difficulty = self._bump_difficulty(difficulty)

        # Confidence: enforce a floor that depends on which stage matched.
        confidence_floor = {"keyword": 0.9, "embedding": 0.6}
        confidence = type_result.confidence
        if type_result.source in confidence_floor:
            confidence = max(confidence, confidence_floor[type_result.source])

        # Source (backward compatible): keep the type classifier's source,
        # unless it fell back to the default while a difficulty rule matched.
        type_defaulted = type_result.source == "default"
        rule_matched = difficulty_result.source == "rule"
        source = "dynamic_difficulty" if type_defaulted and rule_matched else type_result.source

        return ClassificationResult(
            task_type=task_type,
            estimated_difficulty=difficulty,
            confidence=confidence,
            source=source
        )

    def _adjust_difficulty(self, difficulty: str, delta: int) -> str:
        """Shift a difficulty by ``delta`` tiers.

        Args:
            difficulty: Current difficulty.
            delta: Number of tiers to move (positive raises, negative lowers).

        Returns:
            The shifted difficulty, or the input unchanged when it is not a
            known tier or the shift would leave the tier list.
        """
        tiers = self.DIFFICULTY_ORDER
        if difficulty not in tiers:
            return difficulty
        target = tiers.index(difficulty) + delta
        if 0 <= target < len(tiers):
            return tiers[target]
        return difficulty

    def _bump_difficulty(self, difficulty: str) -> str:
        """Raise the difficulty by one tier."""
        return self._adjust_difficulty(difficulty, 1)

    def _lower_difficulty(self, difficulty: str) -> str:
        """Lower the difficulty by one tier."""
        return self._adjust_difficulty(difficulty, -1)

    def _match_pattern(self, text: str, pattern: str) -> bool:
        """Match ``pattern`` as a case-insensitive regex, or as a substring if it is invalid."""
        try:
            found = re.search(pattern, text, re.IGNORECASE)
        except re.error:
            return pattern.lower() in text
        return bool(found)
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
__all__ = ["TaskClassifier", "TaskTypeClassifier", "TaskTypeResult"]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Classifier Types"""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class ClassificationResult:
    """Combined classification result: task type plus estimated difficulty."""

    # Name of the task type that was selected.
    task_type: str
    # Estimated difficulty label for the request.
    estimated_difficulty: str
    # Confidence score attached to the classification.
    confidence: float
    # Label naming which stage produced the result.
    source: str
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_default_classification() -> ClassificationResult:
    """Return the fallback classification: chat / medium, confidence 0.0, source "default"."""
    fallback = ClassificationResult(
        task_type="chat",
        estimated_difficulty="medium",
        confidence=0.0,
        source="default",
    )
    return fallback
|