autosnippet 3.0.3 → 3.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +85 -240
- package/dashboard/dist/assets/{icons-Cdq22n2i.js → icons-eQ_rWCus.js} +97 -102
- package/dashboard/dist/assets/index-B3Nnkdxi.js +133 -0
- package/dashboard/dist/assets/index-BFNDAqh3.css +1 -0
- package/dashboard/dist/index.html +3 -3
- package/lib/core/AstAnalyzer.js +2 -2
- package/lib/core/ast/ensure-grammars.js +2 -0
- package/lib/core/ast/index.js +8 -0
- package/lib/core/ast/lang-rust.js +695 -0
- package/lib/core/discovery/PythonDiscoverer.js +3 -0
- package/lib/core/discovery/RustDiscoverer.js +467 -0
- package/lib/core/discovery/index.js +3 -0
- package/lib/core/enhancement/django-enhancement.js +169 -3
- package/lib/core/enhancement/fastapi-enhancement.js +149 -3
- package/lib/core/enhancement/go-grpc-enhancement.js +4 -0
- package/lib/core/enhancement/go-web-enhancement.js +6 -0
- package/lib/core/enhancement/index.js +5 -0
- package/lib/core/enhancement/langchain-enhancement.js +233 -0
- package/lib/core/enhancement/ml-enhancement.js +265 -0
- package/lib/core/enhancement/nextjs-enhancement.js +219 -0
- package/lib/core/enhancement/node-server-enhancement.js +178 -4
- package/lib/core/enhancement/react-enhancement.js +165 -4
- package/lib/core/enhancement/rust-tokio-enhancement.js +231 -0
- package/lib/core/enhancement/rust-web-enhancement.js +256 -0
- package/lib/core/enhancement/spring-enhancement.js +2 -0
- package/lib/core/enhancement/vue-enhancement.js +143 -2
- package/lib/external/ai/AiProvider.js +45 -6
- package/lib/external/mcp/handlers/bootstrap/skills.js +2 -0
- package/lib/external/mcp/handlers/bootstrap.js +33 -9
- package/lib/external/mcp/handlers/guard.js +42 -0
- package/lib/http/routes/candidates.js +7 -1
- package/lib/service/chat/ChatAgent.js +1 -0
- package/lib/service/chat/tools.js +5 -1
- package/lib/service/guard/ComplianceReporter.js +20 -7
- package/lib/service/guard/GuardCheckEngine.js +156 -5
- package/lib/service/guard/SourceFileCollector.js +15 -0
- package/package.json +28 -6
- package/skills/autosnippet-coldstart/SKILL.md +4 -2
- package/skills/autosnippet-concepts/SKILL.md +5 -3
- package/skills/autosnippet-reference-rust/SKILL.md +401 -0
- package/skills/autosnippet-structure/SKILL.md +1 -1
- package/templates/recipes-setup/README.md +2 -2
- package/templates/recipes-setup/_template.md +1 -1
- package/dashboard/dist/assets/index-ClkyPkDX.js +0 -133
- package/dashboard/dist/assets/index-t4QrJwv1.css +0 -1
|
/**
 * LangChain / Agent Enhancement Pack
 *
 * Activation conditions: { languages: ['python'], frameworks: ['langchain'] }
 *
 * Covers the LLM agent development ecosystem:
 * - LangChain Chain / Agent / Tool
 * - RAG pipelines (Retriever → Splitter → Embedding → VectorStore)
 * - Prompt templates / output parsers
 * - LlamaIndex query engines
 * - Multi-turn conversation memory
 * - Streaming / callbacks
 */

import { EnhancementPack } from './EnhancementPack.js';

// Name fragments that flag a free function as agent/chain/RAG wiring code.
const SETUP_NAME_HINTS = [
  'create_agent',
  'build_chain',
  'setup_rag',
  'create_retriever',
  'get_llm',
  'create_chain',
  'build_graph',
];

// Name fragments that suggest a LangGraph node or routing function.
// Deliberately broad ('node' matches many names), hence the low confidence
// attached where these are used.
const GRAPH_NAME_HINTS = ['node', 'should_continue', 'route_'];

// Import-path fragments that belong to the LangChain / RAG ecosystem.
const ECOSYSTEM_IMPORT_HINTS = [
  'langchain',
  'langgraph',
  'langsmith',
  'llama_index',
  'chromadb',
  'pinecone',
  'faiss',
  'openai',
  'anthropic',
];

class LangChainEnhancement extends EnhancementPack {
  /** Stable identifier for this pack. */
  get id() {
    return 'python-langchain';
  }

  /** Human-readable pack name. */
  get displayName() {
    return 'LangChain / Agent Enhancement';
  }

  /** The pack activates only for Python projects using LangChain. */
  get conditions() {
    return { languages: ['python'], frameworks: ['langchain'] };
  }

  /**
   * Extra analysis dimensions contributed by this pack.
   *
   * Each entry drives one scan; `dualOutput` plus `skillMeta` means the scan
   * feeds both knowledge extraction and an auto-generated skill document.
   *
   * @returns {Array<object>} Dimension descriptors.
   */
  getExtraDimensions() {
    return [
      {
        id: 'langchain-chain-scan',
        label: 'Chain / Agent 分析',
        guide:
          'LangChain Chain 拓扑分析 — LCEL (RunnableSequence | RunnableParallel) 链路、Agent 工具注册、AgentExecutor 配置、工具函数实现 (@tool decorator)',
        knowledgeTypes: ['architecture', 'code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-langchain-chains',
          description:
            'LangChain chain topology, LCEL pipelines and agent tool registrations (auto-generated by enhancement)',
        },
      },
      {
        id: 'langchain-rag-scan',
        label: 'RAG Pipeline 分析',
        guide:
          'RAG 检索增强生成管道分析 — Document Loader / Text Splitter / Embedding Model / VectorStore (Chroma/FAISS/Pinecone) 选型、Retriever 配置 (search_type/k/score_threshold)、Reranking 策略',
        knowledgeTypes: ['architecture', 'code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-langchain-rag',
          description:
            'RAG pipeline — document loading, splitting, embedding and vector retrieval (auto-generated by enhancement)',
        },
      },
      {
        id: 'langchain-prompt-scan',
        label: 'Prompt / Output 模式分析',
        guide:
          'Prompt 工程分析 — PromptTemplate / ChatPromptTemplate 结构、Few-shot 示例管理、Output Parser (JSON/Pydantic/Structured)、System Message 设计模式',
        knowledgeTypes: ['code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-langchain-prompts',
          description:
            'Prompt templates, few-shot examples and output parser configurations (auto-generated by enhancement)',
        },
      },
    ];
  }

  /**
   * Regex-based guard rules applied to Python source files.
   *
   * NOTE(review): these are single-pass heuristics over whole-file text;
   * the negative lookaheads (e.g. in langchain-no-bare-invoke) are
   * position-sensitive and may under-report — confirm against the guard
   * engine's matching semantics before tightening them.
   *
   * @returns {Array<object>} Guard rule descriptors.
   */
  getGuardRules() {
    return [
      {
        ruleId: 'langchain-prompt-injection',
        category: 'safety',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        pattern: /(?:PromptTemplate|ChatPromptTemplate)[\s\S]*?(?:f['"]|\.format\s*\()/,
        message: 'Prompt 中直接拼接用户输入可能导致 Prompt Injection — 使用 input_variables 参数化',
      },
      {
        ruleId: 'langchain-no-bare-invoke',
        category: 'correctness',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        pattern: /\.invoke\s*\([^)]*\)\s*(?![\s\S]*?(?:try|except|catch))/,
        message: 'Chain/Agent invoke 应包含错误处理 — LLM 调用可能超时/限流/返回异常格式',
      },
      {
        ruleId: 'langchain-token-budget',
        category: 'performance',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        pattern: /max_tokens\s*=\s*(?:None|0)/,
        message: '建议设置合理的 max_tokens 限制 — 防止意外高额 API 费用',
      },
      {
        ruleId: 'langchain-hardcoded-api-key',
        category: 'safety',
        dimension: 'file',
        severity: 'error',
        languages: ['python'],
        pattern: /(?:api_key|openai_api_key|anthropic_api_key)\s*=\s*['"][^'"]+['"]/i,
        message: 'API Key 不应硬编码在代码中 — 使用环境变量 (os.environ) 或 .env 文件',
      },
    ];
  }

  /**
   * Scan an AST summary for LangChain-related patterns.
   *
   * @param {object} astSummary - Parsed file summary; assumed to expose
   *   `classes` ({ name, superclass, line }), `methods`
   *   ({ name, className, line, decorators }) and `imports` (string[]) —
   *   shapes come from the AST analyzer, verify there if they change.
   * @returns {Array<object>} Detected pattern descriptors, grouped by kind
   *   in a fixed order.
   */
  detectPatterns(astSummary) {
    const found = [];
    const classes = astSummary.classes || [];
    const methods = astSummary.methods || [];

    // Classes whose superclass name matches the given regex.
    const classesExtending = (re) =>
      classes.filter((klass) => klass.superclass && re.test(klass.superclass));

    // ── Chain / Runnable / Tool / Retriever subclasses ──
    for (const klass of classesExtending(/Runnable|Chain|BaseTool|BaseRetriever/)) {
      found.push({
        type: 'langchain-chain',
        className: klass.name,
        line: klass.line,
        confidence: 0.9,
      });
    }

    // ── Functions decorated with @tool ──
    for (const fn of methods) {
      if (fn.decorators?.some((dec) => /@tool/.test(dec))) {
        found.push({
          type: 'langchain-tool',
          methodName: fn.name,
          line: fn.line,
          confidence: 0.95,
        });
      }
    }

    // ── Free functions whose names look like agent/RAG setup entry points ──
    for (const fn of methods) {
      if (fn.className || !fn.name) continue;
      const lower = fn.name.toLowerCase();
      if (SETUP_NAME_HINTS.some((hint) => lower.includes(hint))) {
        found.push({
          type: 'langchain-setup',
          methodName: fn.name,
          line: fn.line,
          confidence: 0.8,
        });
      }
    }

    // ── Callback handler subclasses ──
    for (const klass of classesExtending(/CallbackHandler|BaseCallbackHandler/)) {
      found.push({
        type: 'langchain-callback',
        className: klass.name,
        line: klass.line,
        confidence: 0.9,
      });
    }

    // ── Output parser subclasses ──
    for (const klass of classesExtending(/OutputParser|BaseOutputParser/)) {
      found.push({
        type: 'langchain-output-parser',
        className: klass.name,
        line: klass.line,
        confidence: 0.9,
      });
    }

    // ── Possible LangGraph nodes / routers (name-based, low confidence) ──
    for (const fn of methods) {
      if (fn.className || !fn.name) continue;
      const lower = fn.name.toLowerCase();
      if (GRAPH_NAME_HINTS.some((hint) => lower.includes(hint))) {
        found.push({
          type: 'langgraph-node',
          methodName: fn.name,
          line: fn.line,
          confidence: 0.5,
        });
      }
    }

    // ── Ecosystem imports (langchain, vector stores, LLM SDKs, ...) ──
    const ecosystemImports = (astSummary.imports || []).filter((imp) =>
      ECOSYSTEM_IMPORT_HINTS.some((hint) => imp.includes(hint))
    );
    if (ecosystemImports.length > 0) {
      found.push({
        type: 'langchain-ecosystem-usage',
        importCount: ecosystemImports.length,
        confidence: 0.85,
      });
    }

    return found;
  }
}

export const pack = new LangChainEnhancement();
/**
 * AI/ML Enhancement Pack
 *
 * Activation conditions: { languages: ['python'], frameworks: ['ml'] }
 *
 * Covers the PyTorch / TensorFlow / HuggingFace machine-learning ecosystem:
 * - nn.Module model architectures
 * - Training-loop patterns (optimizer.zero_grad → loss.backward → optimizer.step)
 * - DataLoader / Dataset
 * - HuggingFace Trainer / Pipeline
 * - Model save/load (state_dict / safetensors)
 * - Distributed training (DDP / FSDP)
 */

import { EnhancementPack } from './EnhancementPack.js';

// Exact (lower-cased) function names treated as training entry points.
const TRAINING_FN_NAMES = new Set([
  'train',
  'train_one_epoch',
  'train_step',
  'training_step',
  'train_loop',
]);

// Exact (lower-cased) function names treated as evaluation entry points.
const EVALUATION_FN_NAMES = new Set([
  'evaluate',
  'validate',
  'eval_step',
  'validation_step',
  'test_step',
]);

// Import-path fragments that belong to the ML ecosystem.
const ML_IMPORT_HINTS = [
  'torch',
  'tensorflow',
  'transformers',
  'datasets',
  'accelerate',
  'lightning',
  'sklearn',
  'numpy',
  'wandb',
  'tensorboard',
];

class MLEnhancement extends EnhancementPack {
  /** Stable identifier for this pack. */
  get id() {
    return 'python-ml';
  }

  /** Human-readable pack name. */
  get displayName() {
    return 'Python AI/ML (PyTorch/HuggingFace) Enhancement';
  }

  /** The pack activates only for Python projects flagged as ML. */
  get conditions() {
    return { languages: ['python'], frameworks: ['ml'] };
  }

  /**
   * Extra analysis dimensions contributed by this pack.
   *
   * Each entry drives one scan; `dualOutput` plus `skillMeta` means the scan
   * feeds both knowledge extraction and an auto-generated skill document.
   *
   * @returns {Array<object>} Dimension descriptors.
   */
  getExtraDimensions() {
    return [
      {
        id: 'ml-model-architecture-scan',
        label: '模型架构分析',
        guide:
          'nn.Module 子类分析 — 层级结构 (forward 方法调用链)、自定义 Layer、残差连接/Attention 模式、参数量估算、模型注册表',
        knowledgeTypes: ['architecture', 'code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-models',
          description:
            'PyTorch model architectures, custom layers and forward pass patterns (auto-generated by enhancement)',
        },
      },
      {
        id: 'ml-training-pipeline-scan',
        label: 'Training Pipeline 分析',
        guide:
          'Training 流程分析 — Training Loop 结构 (epoch → batch → forward → loss → backward → step)、learning rate scheduler、gradient clipping/accumulation、早停策略、checkpoint 保存/恢复、HuggingFace Trainer 配置',
        knowledgeTypes: ['code-pattern'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-training',
          description:
            'ML training pipeline — training loops, schedulers, checkpointing and HF Trainer (auto-generated by enhancement)',
        },
      },
      {
        id: 'ml-data-pipeline-scan',
        label: '数据管道分析',
        guide:
          '数据处理管道分析 — Dataset/DataLoader 实现、数据增强 (transforms)、tokenizer 配置、特征工程、数据拆分策略 (train/val/test)',
        knowledgeTypes: ['code-pattern', 'architecture'],
        skillWorthy: true,
        dualOutput: true,
        skillMeta: {
          name: 'project-ml-data',
          description:
            'ML data pipelines — Dataset/DataLoader, transforms and tokenization (auto-generated by enhancement)',
        },
      },
    ];
  }

  /**
   * Regex-based guard rules applied to Python source files.
   *
   * NOTE(review): the negative lookaheads placed after lazy [\s\S]*? spans
   * (ml-missing-no-grad, ml-gradient-accumulation-zero) are order-sensitive
   * heuristics rather than precise checks — confirm against the guard
   * engine's matching semantics before tightening them.
   *
   * @returns {Array<object>} Guard rule descriptors.
   */
  getGuardRules() {
    return [
      {
        ruleId: 'ml-no-eval-in-train',
        category: 'correctness',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        pattern: /model\.train\(\)[\s\S]*?model\.eval\(\)[\s\S]*?loss\.backward/,
        message: '训练循环中意外切换到 eval 模式后执行 backward — 确保 model.train() 在训练阶段',
      },
      {
        ruleId: 'ml-device-mismatch',
        category: 'correctness',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        pattern: /\.to\s*\(\s*['"](?:cuda|cpu)['"]\s*\)/,
        message: '硬编码 device 字符串 — 建议使用 torch.device() 变量统一管理,支持 MPS/多 GPU',
      },
      {
        ruleId: 'ml-missing-no-grad',
        category: 'performance',
        dimension: 'file',
        severity: 'warning',
        languages: ['python'],
        pattern: /def\s+(?:evaluate|validate|test|predict|inference)\s*\([^)]*\)[\s\S]*?(?!torch\.no_grad|@torch\.no_grad)model\s*\(/,
        message: '推理/评估函数应使用 @torch.no_grad() 或 with torch.no_grad() — 减少内存消耗并加速',
      },
      {
        ruleId: 'ml-gradient-accumulation-zero',
        category: 'correctness',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        pattern: /loss\.backward\(\)[\s\S]*?optimizer\.step\(\)[\s\S]*?(?!optimizer\.zero_grad)/,
        message: 'optimizer.step() 后应调用 optimizer.zero_grad() — 否则梯度会累积',
      },
      {
        ruleId: 'ml-random-seed',
        category: 'correctness',
        dimension: 'file',
        severity: 'info',
        languages: ['python'],
        pattern: /torch\.manual_seed|random\.seed|np\.random\.seed/,
        message: '设置随机种子时建议同时设置 torch.manual_seed / torch.cuda.manual_seed_all / np.random.seed / random.seed 保证完全可复现',
      },
    ];
  }

  /**
   * Scan an AST summary for ML-related patterns.
   *
   * @param {object} astSummary - Parsed file summary; assumed to expose
   *   `classes` ({ name, superclass, line }), `methods`
   *   ({ name, className, line }) and `imports` (string[]) — shapes come
   *   from the AST analyzer, verify there if they change.
   * @returns {Array<object>} Detected pattern descriptors, grouped by kind
   *   in a fixed order.
   */
  detectPatterns(astSummary) {
    const found = [];
    const classes = astSummary.classes || [];
    const methods = astSummary.methods || [];

    // Classes whose superclass name matches the given regex.
    const classesExtending = (re) =>
      classes.filter((klass) => klass.superclass && re.test(klass.superclass));

    // ── nn.Module subclasses (model definitions) ──
    for (const klass of classesExtending(/Module$|nn\.Module/)) {
      found.push({
        type: 'pytorch-model',
        className: klass.name,
        line: klass.line,
        confidence: 0.95,
      });
    }

    // ── Dataset subclasses ──
    for (const klass of classesExtending(/Dataset$|IterableDataset/)) {
      found.push({
        type: 'pytorch-dataset',
        className: klass.name,
        line: klass.line,
        confidence: 0.9,
      });
    }

    // ── HuggingFace model / config / trainer subclasses ──
    for (const klass of classesExtending(/PreTrainedModel|PretrainedConfig|Trainer/)) {
      found.push({
        type: 'huggingface-model',
        className: klass.name,
        line: klass.line,
        confidence: 0.9,
      });
    }

    // ── Training entry points, matched by exact (case-insensitive) name ──
    for (const fn of methods) {
      const lower = fn.name?.toLowerCase() || '';
      if (TRAINING_FN_NAMES.has(lower)) {
        found.push({
          type: 'ml-training-function',
          methodName: fn.name,
          line: fn.line,
          confidence: 0.85,
        });
      }
    }

    // ── Evaluation / validation entry points ──
    for (const fn of methods) {
      const lower = fn.name?.toLowerCase() || '';
      if (EVALUATION_FN_NAMES.has(lower)) {
        found.push({
          type: 'ml-evaluation-function',
          methodName: fn.name,
          line: fn.line,
          confidence: 0.85,
        });
      }
    }

    // ── forward() methods defined on a class (nn.Module forward pass) ──
    for (const fn of methods) {
      if (fn.name === 'forward' && fn.className) {
        found.push({
          type: 'pytorch-forward',
          className: fn.className,
          methodName: fn.name,
          line: fn.line,
          confidence: 0.9,
        });
      }
    }

    // ── PyTorch Lightning modules ──
    for (const klass of classesExtending(/LightningModule|LightningDataModule/)) {
      found.push({
        type: 'pytorch-lightning-module',
        className: klass.name,
        line: klass.line,
        confidence: 0.9,
      });
    }

    // ── Custom loss classes: name contains "loss" and extends a Module ──
    for (const klass of classes) {
      const lower = klass.name?.toLowerCase() || '';
      if (lower.includes('loss') && klass.superclass && /Module/.test(klass.superclass)) {
        found.push({
          type: 'ml-custom-loss',
          className: klass.name,
          line: klass.line,
          confidence: 0.85,
        });
      }
    }

    // ── ML ecosystem imports ──
    const mlImports = (astSummary.imports || []).filter((imp) =>
      ML_IMPORT_HINTS.some((hint) => imp.includes(hint))
    );
    if (mlImports.length > 0) {
      found.push({
        type: 'ml-ecosystem-usage',
        importCount: mlImports.length,
        confidence: 0.85,
      });
    }

    return found;
  }
}

export const pack = new MLEnhancement();