jarvis-ai-assistant 0.1.85__py3-none-any.whl → 0.1.87__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic.
- jarvis/__init__.py +1 -1
- jarvis/agent.py +8 -7
- jarvis/jarvis_codebase/main.py +115 -100
- jarvis/jarvis_coder/main.py +104 -36
- jarvis/jarvis_rag/main.py +191 -148
- jarvis/models/ai8.py +1 -1
- jarvis/models/ollama.py +150 -0
- jarvis/models/openai.py +2 -2
- jarvis/models/oyi.py +3 -3
- jarvis/tools/ask_user.py +13 -15
- jarvis/tools/registry.py +10 -8
- jarvis/utils.py +48 -2
- {jarvis_ai_assistant-0.1.85.dist-info → jarvis_ai_assistant-0.1.87.dist-info}/METADATA +2 -4
- {jarvis_ai_assistant-0.1.85.dist-info → jarvis_ai_assistant-0.1.87.dist-info}/RECORD +18 -17
- {jarvis_ai_assistant-0.1.85.dist-info → jarvis_ai_assistant-0.1.87.dist-info}/LICENSE +0 -0
- {jarvis_ai_assistant-0.1.85.dist-info → jarvis_ai_assistant-0.1.87.dist-info}/WHEEL +0 -0
- {jarvis_ai_assistant-0.1.85.dist-info → jarvis_ai_assistant-0.1.87.dist-info}/entry_points.txt +0 -0
- {jarvis_ai_assistant-0.1.85.dist-info → jarvis_ai_assistant-0.1.87.dist-info}/top_level.txt +0 -0
jarvis/jarvis_rag/main.py
CHANGED
@@ -5,7 +5,7 @@ import faiss
 from typing import List, Tuple, Optional, Dict
 from sentence_transformers import SentenceTransformer
 import pickle
-from jarvis.utils import OutputType, PrettyOutput, find_git_root, load_embedding_model
+from jarvis.utils import OutputType, PrettyOutput, find_git_root, get_max_context_length, load_embedding_model, load_rerank_model
 from jarvis.utils import load_env_from_file
 import tiktoken
 from dataclasses import dataclass
@@ -14,6 +14,8 @@ import fitz  # PyMuPDF for PDF files
 from docx import Document as DocxDocument  # python-docx for DOCX files
 from pathlib import Path
 from jarvis.models.registry import PlatformRegistry
+import shutil
+from datetime import datetime
 
 @dataclass
 class Document:
@@ -161,7 +163,7 @@ class RAGTool:
         self.cache_path = os.path.join(self.data_dir, "cache.pkl")
         self.documents: List[Document] = []
         self.index = None
-        self.max_context_length =
+        self.max_context_length = get_max_context_length()
 
         # Load the cache
         self._load_cache()
@@ -193,15 +195,30 @@ class RAGTool:
         self.index = None
 
     def _save_cache(self, vectors: np.ndarray):
-        """
+        """Optimized cache saving"""
         try:
+            # Add a version number and timestamp
             cache_data = {
+                "version": "1.0",
+                "timestamp": datetime.now().isoformat(),
                 "documents": self.documents,
-                "vectors": vectors
+                "vectors": vectors,
+                "metadata": {
+                    "vector_dim": self.vector_dim,
+                    "total_docs": len(self.documents),
+                    "model_name": self.embedding_model.__class__.__name__
+                }
             }
+
+            # Store using the highest pickle protocol
             with open(self.cache_path, 'wb') as f:
-                pickle.dump(cache_data, f)
-
+                pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+            # Create a backup
+            backup_path = f"{self.cache_path}.backup"
+            shutil.copy2(self.cache_path, backup_path)
+
+            PrettyOutput.print(f"Cache saved: {len(self.documents)} document chunks",
                               output_type=OutputType.INFO)
         except Exception as e:
             PrettyOutput.print(f"Failed to save cache: {str(e)}",
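The rewritten `_save_cache` adds a version, timestamp, and metadata block, dumps with the highest pickle protocol, and copies the result to a `.backup` file. A minimal standalone sketch of the same pattern (names are illustrative, not taken from the package):

```python
import pickle
import shutil
from datetime import datetime

def save_cache(cache_path: str, documents: list, vectors) -> None:
    """Write a versioned cache file, then keep a backup copy next to it."""
    cache_data = {
        "version": "1.0",
        "timestamp": datetime.now().isoformat(),
        "documents": documents,
        "vectors": vectors,
    }
    # HIGHEST_PROTOCOL is the most compact and fastest pickle encoding
    # (a newer binary format, not actual compression).
    with open(cache_path, "wb") as f:
        pickle.dump(cache_data, f, protocol=pickle.HIGHEST_PROTOCOL)
    # copy2 preserves file metadata such as the modification time.
    shutil.copy2(cache_path, f"{cache_path}.backup")
```

Since the backup is copied after the new cache is written, it duplicates the fresh file rather than preserving the previous version; moving the old file aside before writing would be needed for a true rollback copy.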
@@ -209,100 +226,74 @@ class RAGTool:
 
     def _build_index(self, vectors: np.ndarray):
         """Build the FAISS index"""
-        #
-
-
-
-
-        # Wrap the HNSW index with IndexIDMap
-        self.index = faiss.IndexIDMap(hnsw_index)
+        # Add an IVF index to improve large-scale retrieval performance
+        nlist = max(4, int(vectors.shape[0] / 1000))  # one cluster centroid per 1000 vectors
+        quantizer = faiss.IndexFlatIP(self.vector_dim)
+        self.index = faiss.IndexIVFFlat(quantizer, self.vector_dim, nlist, faiss.METRIC_INNER_PRODUCT)
 
-        # Add vectors to the index
         if vectors.shape[0] > 0:
-
+            # Train the IVF index
+            self.index.train(vectors)
+            self.index.add(vectors)
+            # Number of clusters to probe at search time
+            self.index.nprobe = min(nlist, 10)
         else:
             self.index = None
 
     def _split_text(self, text: str) -> List[str]:
-        """
+        """A smarter chunking strategy"""
+        # Add overlapping chunks to preserve context continuity
+        overlap_size = min(200, self.max_paragraph_length // 4)
 
-        Args:
-            text: the text to split
-
-        Returns:
-            the list of paragraphs after splitting
-        """
-        # First split on blank lines
         paragraphs = []
-
-
-        for line in text.split('\n'):
-            line = line.strip()
-            if not line:  # a blank line ends the paragraph
-                if current_paragraph:
-                    paragraph_text = ' '.join(current_paragraph)
-                    if len(paragraph_text) >= self.min_paragraph_length:
-                        paragraphs.append(paragraph_text)
-                    current_paragraph = []
-            else:
-                current_paragraph.append(line)
+        current_chunk = []
+        current_length = 0
 
-        #
-
-
-
-            paragraphs.append(paragraph_text)
+        # First split into sentences
+        sentences = []
+        current_sentence = []
+        sentence_ends = {'。', '!', '?', '…', '.', '!', '?'}
 
-
-
-
-
-
-
-        # Split overly long paragraphs by sentence
-        sentences = []
+        for char in text:
+            current_sentence.append(char)
+            if char in sentence_ends:
+                sentence = ''.join(current_sentence)
+                if sentence.strip():
+                    sentences.append(sentence)
                 current_sentence = []
-
-
-
-
-
-
-
-
-
-                sentences.append(sentence)
-                current_sentence = []
-
-        # Handle the last sentence
-        if current_sentence:
-            sentence = ''.join(current_sentence)
-            if sentence.strip():
-                sentences.append(sentence)
-
-        # Combine sentences into paragraphs of suitable length
-        current_chunk = []
-        current_length = 0
-
-        for sentence in sentences:
-            sentence_length = len(sentence)
-            if current_length + sentence_length > self.max_paragraph_length:
-                if current_chunk:
-                    final_paragraphs.append(''.join(current_chunk))
-                current_chunk = [sentence]
-                current_length = sentence_length
-            else:
-                current_chunk.append(sentence)
-                current_length += sentence_length
-
-        # Handle the last chunk
+
+        if current_sentence:
+            sentence = ''.join(current_sentence)
+            if sentence.strip():
+                sentences.append(sentence)
+
+        # Build overlapping chunks from the sentences
+        for sentence in sentences:
+            if current_length + len(sentence) > self.max_paragraph_length:
                 if current_chunk:
-
+                    chunk_text = ' '.join(current_chunk)
+                    if len(chunk_text) >= self.min_paragraph_length:
+                        paragraphs.append(chunk_text)
+
+                    # Keep some of the content as overlap
+                    overlap_text = ' '.join(current_chunk[-2:])  # keep the last two sentences
+                    current_chunk = []
+                    if overlap_text:
+                        current_chunk.append(overlap_text)
+                        current_length = len(overlap_text)
+                    else:
+                        current_length = 0
+
+            current_chunk.append(sentence)
+            current_length += len(sentence)
 
-        #
-
+        # Handle the last chunk
+        if current_chunk:
+            chunk_text = ' '.join(current_chunk)
+            if len(chunk_text) >= self.min_paragraph_length:
+                paragraphs.append(chunk_text)
 
-        return
+        return paragraphs
 
     def _get_embedding(self, text: str) -> np.ndarray:
         """Get the vector representation of the text"""
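The switch from the previous HNSW/IndexIDMap setup to `IndexIVFFlat` changes the index lifecycle: an IVF index clusters vectors around `nlist` centroids and must be trained before `add()`, and `nprobe` sets how many clusters are scanned per query. A self-contained sketch of the same pattern on toy data:

```python
import faiss
import numpy as np

dim = 64
vectors = np.random.rand(5000, dim).astype('float32')
faiss.normalize_L2(vectors)  # on unit vectors, inner product equals cosine similarity

nlist = max(4, vectors.shape[0] // 1000)  # one centroid per ~1000 vectors
quantizer = faiss.IndexFlatIP(dim)        # coarse quantizer that holds the centroids
index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)

index.train(vectors)           # k-means over the corpus; required before add()
index.add(vectors)
index.nprobe = min(nlist, 10)  # clusters probed per query: recall vs. speed trade-off

query = np.random.rand(1, dim).astype('float32')
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)
print(ids[0], scores[0])
```

Note that with `METRIC_INNER_PRODUCT`, larger scores mean closer matches; the `1.0 / (1.0 + distance)` conversion in the `search` method below treats the returned values as distances, which only makes sense if the embeddings are produced accordingly.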
@@ -410,82 +401,131 @@ class RAGTool:
                               output_type=OutputType.SUCCESS)
 
     def search(self, query: str, top_k: int = 5) -> List[Tuple[Document, float]]:
-        """
-
-        Args:
-            query: the query text
-            top_k: the number of results to return
-
-        Returns:
-            a list of documents and similarity scores
-        """
+        """Optimized search strategy"""
         if not self.index:
             PrettyOutput.print("Index not built, building now...", output_type=OutputType.INFO)
             self.build_index(self.root_dir)
+
+        # Implement MMR (Maximal Marginal Relevance) to increase result diversity
+        def mmr(query_vec, doc_vecs, doc_ids, lambda_param=0.5, n_docs=top_k):
+            selected = []
+            selected_ids = []
+
+            while len(selected) < n_docs and len(doc_ids) > 0:
+                best_score = -1
+                best_idx = -1
+
+                for i, (doc_vec, doc_id) in enumerate(zip(doc_vecs, doc_ids)):
+                    # Similarity to the query
+                    query_sim = float(np.dot(query_vec, doc_vec))
+
+                    # Maximum similarity to the documents selected so far
+                    if selected:
+                        doc_sims = [float(np.dot(doc_vec, selected_doc)) for selected_doc in selected]
+                        max_doc_sim = max(doc_sims)
+                    else:
+                        max_doc_sim = 0
+
+                    # MMR score
+                    score = lambda_param * query_sim - (1 - lambda_param) * max_doc_sim
+
+                    if score > best_score:
+                        best_score = score
+                        best_idx = i
+
+                if best_idx == -1:
+                    break
+
+                selected.append(doc_vecs[best_idx])
+                selected_ids.append(doc_ids[best_idx])
+                doc_vecs = np.delete(doc_vecs, best_idx, axis=0)
+                doc_ids = np.delete(doc_ids, best_idx)
 
-
+            return selected_ids
+
+        # Get the query vector
         query_vector = self._get_embedding(query)
         query_vector = query_vector.reshape(1, -1)
 
-        #
-
+        # Initially retrieve more results for MMR to choose from
+        initial_k = min(top_k * 2, len(self.documents))
+        distances, indices = self.index.search(query_vector, initial_k)
+
+        # Collect the valid hits
+        valid_indices = indices[0][indices[0] != -1]
+        valid_vectors = np.vstack([self._get_embedding(self.documents[idx].content) for idx in valid_indices])
 
-        #
+        # Apply MMR
+        final_indices = mmr(query_vector[0], valid_vectors, valid_indices, n_docs=top_k)
+
+        # Build the results
         results = []
-
+        for idx in final_indices:
+            doc = self.documents[idx]
+            similarity = 1.0 / (1.0 + float(distances[0][np.where(indices[0] == idx)[0][0]]))
+            results.append((doc, similarity))
 
-
-
-
+        return results
+
+    def _rerank_results(self, query: str, initial_results: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
+        """Re-rank the search results with a rerank model"""
+        try:
+            import torch
+            model, tokenizer = load_rerank_model()
+
+            # Prepare the data
+            pairs = []
+            for doc, _ in initial_results:
+                # Combine the document information
+                doc_content = f"""
+File: {doc.metadata['file_path']}
+Content: {doc.content}
+"""
+                pairs.append([query, doc_content])
 
-
-
+            # Score each query-document pair
+            scores = []
+            batch_size = 8
 
-
-
-
+            with torch.no_grad():
+                for i in range(0, len(pairs), batch_size):
+                    batch_pairs = pairs[i:i + batch_size]
+                    encoded = tokenizer(
+                        batch_pairs,
+                        padding=True,
+                        truncation=True,
+                        max_length=512,
+                        return_tensors='pt'
+                    )
+
+                    if torch.cuda.is_available():
+                        encoded = {k: v.cuda() for k, v in encoded.items()}
+
+                    outputs = model(**encoded)
+                    batch_scores = outputs.logits.squeeze(-1).cpu().numpy()
+                    scores.extend(batch_scores.tolist())
 
-        #
-
+            # Normalize the scores to the 0-1 range
+            if scores:
+                min_score = min(scores)
+                max_score = max(scores)
+                if max_score > min_score:
+                    scores = [(s - min_score) / (max_score - min_score) for s in scores]
 
-        #
-
-        for
-
-
-
-
-
-            content_parts.extend(file_docs[i].content for i in range(start_idx, current_idx))
-            content_parts.append(doc.content)
-            content_parts.extend(file_docs[i].content for i in range(current_idx + 1, end_idx))
-
-            merged_content = "\n".join(content_parts)
-
-            # Create the document object
-            context_doc = Document(
-                content=merged_content,
-                metadata={
-                    **doc.metadata,
-                    "similarity": similarity
-                }
-            )
-
-            # Compute the total length after adding this result
-            total_content_length = len(merged_content)
-
-            # Check whether it stays within the length limit
-            if current_length + total_content_length <= self.max_context_length:
-                results.append((context_doc, similarity))
-                current_length += total_content_length
-                added = True
-                break
+            # Pair the scores with the documents and sort
+            scored_results = []
+            for (doc, _), score in zip(initial_results, scores):
+                if score >= 0.5:  # keep only results with a relevance of at least 0.5
+                    scored_results.append((doc, float(score)))
+
+            # Sort by score, descending
+            scored_results.sort(key=lambda x: x[1], reverse=True)
 
-
-
-
-
-
+            return scored_results
+
+        except Exception as e:
+            PrettyOutput.print(f"Re-ranking failed, falling back to the original order: {str(e)}", output_type=OutputType.WARNING)
+            return initial_results
 
     def is_index_built(self):
         """Check whether the index has been built"""
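MMR greedily picks, at each step, the candidate that maximizes λ·sim(query, d) − (1−λ)·max-sim(d, already selected), trading relevance against redundancy. A compact standalone illustration of the same scoring rule on toy vectors:

```python
import numpy as np

def mmr_select(query_vec, doc_vecs, lambda_param=0.5, n_docs=2):
    """Greedy MMR over candidate indices."""
    candidates = list(range(len(doc_vecs)))
    selected = []
    while candidates and len(selected) < n_docs:
        def score(i):
            relevance = float(np.dot(query_vec, doc_vecs[i]))
            redundancy = max((float(np.dot(doc_vecs[i], doc_vecs[j]))
                              for j in selected), default=0.0)
            return lambda_param * relevance - (1 - lambda_param) * redundancy
        best = max(candidates, key=score)
        selected.append(best)
        candidates.remove(best)
    return selected

query = np.array([1.0, 0.0, 0.0])
docs = np.array([
    [0.99, 0.14, 0.0],   # highly relevant
    [0.97, 0.24, 0.0],   # nearly a duplicate of the first
    [0.80, 0.00, 0.60],  # less relevant but diverse
])
print(mmr_select(query, docs))  # [0, 2]: the near-duplicate is skipped
```

With λ = 1 this degenerates to plain relevance ranking; lower values push harder for diversity.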
@@ -567,9 +607,12 @@ def main():
     args = parser.parse_args()
 
     try:
-        current_dir =
+        current_dir = os.getcwd()
         rag = RAGTool(current_dir)
 
+        if not args.dir:
+            args.dir = current_dir
+
         if args.dir and args.build:
             PrettyOutput.print(f"Processing directory: {args.dir}", output_type=OutputType.INFO)
             rag.build_index(args.dir)
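Back in `_split_text` above, note that `overlap_size` is computed but never referenced; the overlap is instead fixed at the last two sentences of the closing chunk. A runnable sketch of that carry-over mechanic (parameter values are illustrative):

```python
def split_with_overlap(sentences, max_len=80, min_len=20, keep_last=2):
    """Pack sentences into chunks of at most max_len characters; when a chunk
    closes, its last keep_last sentences seed the next chunk as overlap."""
    chunks, current, length = [], [], 0
    for sentence in sentences:
        if length + len(sentence) > max_len and current:
            text = ' '.join(current)
            if len(text) >= min_len:
                chunks.append(text)
            current = current[-keep_last:]  # carry the tail over as overlap
            length = sum(len(s) for s in current)
        current.append(sentence)
        length += len(sentence)
    if current:  # flush the final chunk
        text = ' '.join(current)
        if len(text) >= min_len:
            chunks.append(text)
    return chunks

sents = [f"Sentence number {i} of the demo corpus." for i in range(6)]
for chunk in split_with_overlap(sents):
    print(repr(chunk))
```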
jarvis/models/ai8.py
CHANGED
jarvis/models/ollama.py
ADDED
@@ -0,0 +1,150 @@
+import requests
+from typing import List, Dict
+from jarvis.models.base import BasePlatform
+from jarvis.utils import OutputType, PrettyOutput
+import os
+import json
+
+class OllamaPlatform(BasePlatform):
+    """Ollama platform implementation"""
+
+    platform_name = "ollama"
+
+    def __init__(self):
+        """Initialize the model"""
+        super().__init__()
+
+        # Check the environment variables and provide help information
+        self.api_base = os.getenv("OLLAMA_API_BASE", "http://localhost:11434")
+        self.model_name = os.getenv("JARVIS_MODEL") or "deepseek-r1:1.5b"
+
+        # Check whether the Ollama service is available
+        try:
+            PrettyOutput.print(f"Connecting to the Ollama service ({self.api_base})...", OutputType.INFO)
+            response = requests.get(f"{self.api_base}/api/tags")
+            response.raise_for_status()
+            available_models = [model["name"] for model in response.json().get("models", [])]
+
+            if not available_models:
+                PrettyOutput.print("\nAn Ollama model must be downloaded before use:", OutputType.INFO)
+                PrettyOutput.print("1. Install Ollama: https://ollama.ai", OutputType.INFO)
+                PrettyOutput.print("2. Download a model:", OutputType.INFO)
+                PrettyOutput.print(f"   ollama pull {self.model_name}", OutputType.INFO)
+                raise Exception("No available models found")
+
+            PrettyOutput.print(f"Available models: {', '.join(available_models)}", OutputType.INFO)
+
+            if self.model_name not in available_models:
+                PrettyOutput.print(f"\nWarning: model {self.model_name} has not been downloaded", OutputType.WARNING)
+                PrettyOutput.print("\nPlease download it with the following command:", OutputType.INFO)
+                PrettyOutput.print(f"ollama pull {self.model_name}", OutputType.INFO)
+                raise Exception(f"Model {self.model_name} is not available")
+
+            PrettyOutput.print(f"Using model: {self.model_name}", OutputType.SUCCESS)
+
+        except requests.exceptions.ConnectionError:
+            PrettyOutput.print("\nThe Ollama service is not running or cannot be reached", OutputType.ERROR)
+            PrettyOutput.print("Please make sure you have:", OutputType.INFO)
+            PrettyOutput.print("1. Installed Ollama: https://ollama.ai", OutputType.INFO)
+            PrettyOutput.print("2. Started the Ollama service", OutputType.INFO)
+            PrettyOutput.print("3. Configured the service address correctly (default: http://localhost:11434)", OutputType.INFO)
+            raise Exception("Ollama service is not available")
+
+        self.messages = []
+        self.system_message = ""
+
+    def set_model_name(self, model_name: str):
+        """Set the model name"""
+        self.model_name = model_name
+
+    def chat(self, message: str) -> str:
+        """Run a chat turn"""
+        try:
+            # Build the message list
+            messages = []
+            if self.system_message:
+                messages.append({"role": "system", "content": self.system_message})
+            messages.extend(self.messages)
+            messages.append({"role": "user", "content": message})
+
+            # Build the request payload
+            data = {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": True  # enable streaming output
+            }
+
+            # Send the request
+            response = requests.post(
+                f"{self.api_base}/api/chat",
+                json=data,
+                stream=True
+            )
+            response.raise_for_status()
+
+            # Handle the streaming response
+            full_response = ""
+            for line in response.iter_lines():
+                if line:
+                    chunk = line.decode()
+                    try:
+                        result = json.loads(chunk)
+                        if "message" in result and "content" in result["message"]:
+                            text = result["message"]["content"]
+                            if not self.suppress_output:
+                                PrettyOutput.print_stream(text)
+                            full_response += text
+                    except json.JSONDecodeError:
+                        continue
+
+            if not self.suppress_output:
+                PrettyOutput.print_stream_end()
+
+            # Update the message history
+            self.messages.append({"role": "user", "content": message})
+            self.messages.append({"role": "assistant", "content": full_response})
+
+            return full_response
+
+        except Exception as e:
+            PrettyOutput.print(f"Chat failed: {str(e)}", OutputType.ERROR)
+            raise Exception(f"Chat failed: {str(e)}")
+
+    def upload_files(self, file_list: List[str]) -> List[Dict]:
+        """Upload files (not supported by Ollama)"""
+        PrettyOutput.print("Ollama does not support file uploads", output_type=OutputType.WARNING)
+        return []
+
+    def reset(self):
+        """Reset the model state"""
+        self.messages = []
+        if self.system_message:
+            self.messages.append({"role": "system", "content": self.system_message})
+
+    def name(self) -> str:
+        """Return the model name"""
+        return self.model_name
+
+    def delete_chat(self) -> bool:
+        """Delete the current chat session"""
+        self.reset()
+        return True
+
+    def set_system_message(self, message: str):
+        """Set the system message"""
+        self.system_message = message
+        self.reset()  # reset the session to apply the new system message
+
+
+if __name__ == "__main__":
+    try:
+        ollama = OllamaPlatform()
+        while True:
+            try:
+                message = input("\nEnter a question (Ctrl+C to exit): ")
+                ollama.chat(message)
+            except KeyboardInterrupt:
+                print("\nGoodbye!")
+                break
+    except Exception as e:
+        PrettyOutput.print(f"Program exited abnormally: {str(e)}", OutputType.ERROR)
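For reference, a minimal interaction with the class above, assuming a local Ollama server with the model already pulled (note that `self.suppress_output` is read in `chat` but never set in this file, so it presumably comes from `BasePlatform`):

```python
from jarvis.models.ollama import OllamaPlatform

ollama = OllamaPlatform()  # verifies the server is reachable and the model exists
ollama.set_system_message("You are a terse assistant.")
reply = ollama.chat("Why is the sky blue?")  # streams tokens, returns the full text
ollama.reset()  # clears the history but re-applies the system message
```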
jarvis/models/openai.py
CHANGED
@@ -30,8 +30,8 @@ class OpenAIModel(BasePlatform):
             PrettyOutput.print("  export OPENAI_MODEL_NAME=your_model_name", OutputType.INFO)
             raise Exception("OPENAI_API_KEY is not set")
 
-        self.base_url = os.getenv("OPENAI_API_BASE", "https://api.
-        self.model_name =
+        self.base_url = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
+        self.model_name = os.getenv("JARVIS_MODEL") or "gpt-4o"
 
 
         self.client = OpenAI(
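The model default now uses `os.getenv("JARVIS_MODEL") or "gpt-4o"` rather than the two-argument form of `getenv`. The difference shows when the variable is set but empty:

```python
import os

os.environ["JARVIS_MODEL"] = ""  # set, but empty

# The two-argument form only falls back when the variable is unset:
print(os.getenv("JARVIS_MODEL", "gpt-4o"))    # -> "" (the empty string wins)

# The `or` form also falls back on empty (falsy) values:
print(os.getenv("JARVIS_MODEL") or "gpt-4o")  # -> "gpt-4o"
```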
jarvis/models/oyi.py
CHANGED
@@ -2,7 +2,7 @@ import mimetypes
 import os
 from typing import Dict, List
 from jarvis.models.base import BasePlatform
-from jarvis.utils import PrettyOutput, OutputType
+from jarvis.utils import PrettyOutput, OutputType, get_max_context_length
 import requests
 import json
 
@@ -72,10 +72,10 @@ class OyiModel(BasePlatform):
                     "is_webSearch": True,
                     "message": [],
                     "systemMessage": None,
-                    "requestMsgCount":
+                    "requestMsgCount": 65536,
                     "temperature": 0.8,
                     "speechVoice": "Alloy",
-                    "max_tokens":
+                    "max_tokens": get_max_context_length(),
                     "chatPluginIds": []
                 })
             }
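`get_max_context_length` comes from `jarvis/utils.py`, whose hunk is not expanded in this diff, so its body is unknown here. A plausible shape, with the environment variable name and default being pure guesses:

```python
import os

def get_max_context_length() -> int:
    # Hypothetical sketch only: the variable name and the 4096 default
    # are assumptions, not taken from jarvis/utils.py.
    return int(os.getenv("JARVIS_MAX_CONTEXT_LENGTH", "4096"))
```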
jarvis/tools/ask_user.py
CHANGED
@@ -2,11 +2,9 @@ from typing import Dict, Any
 from jarvis.tools.base import Tool
 from jarvis.utils import get_multiline_input, PrettyOutput, OutputType
 
-class AskUserTool
-
-
-    name="ask_user",
-    description="""Ask the user when information needed to complete the task is missing or key decision information is absent.
+class AskUserTool:
+    name="ask_user",
+    description="""Ask the user when information needed to complete the task is missing or key decision information is absent.
 The user can enter multiple lines of text; a blank line ends the input.
 
 Usage scenarios:
@@ -17,17 +15,17 @@ class AskUserTool(Tool):
 
 Parameter description:
 - question: the question to ask the user; it should be clear and unambiguous""",
-
-
-
-
-
-
-        }
-    },
-    "required": ["question"]
+    parameters={
+        "type": "object",
+        "properties": {
+            "question": {
+                "type": "string",
+                "description": "The question to ask the user"
             }
-
+        },
+        "required": ["question"]
+    }
+
 
     def execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
         """Execute the ask-user operation
|