jarvis-ai-assistant 0.1.132__py3-none-any.whl → 0.1.138__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jarvis-ai-assistant might be problematic. Click here for more details.

Files changed (82)
  1. jarvis/__init__.py +1 -1
  2. jarvis/jarvis_agent/__init__.py +330 -347
  3. jarvis/jarvis_agent/builtin_input_handler.py +16 -6
  4. jarvis/jarvis_agent/file_input_handler.py +9 -9
  5. jarvis/jarvis_agent/jarvis.py +143 -0
  6. jarvis/jarvis_agent/main.py +12 -13
  7. jarvis/jarvis_agent/output_handler.py +3 -3
  8. jarvis/jarvis_agent/patch.py +92 -64
  9. jarvis/jarvis_agent/shell_input_handler.py +5 -3
  10. jarvis/jarvis_code_agent/code_agent.py +263 -177
  11. jarvis/jarvis_code_agent/file_select.py +24 -24
  12. jarvis/jarvis_dev/main.py +45 -59
  13. jarvis/jarvis_git_details/__init__.py +0 -0
  14. jarvis/jarvis_git_details/main.py +179 -0
  15. jarvis/jarvis_git_squash/main.py +7 -7
  16. jarvis/jarvis_lsp/base.py +11 -53
  17. jarvis/jarvis_lsp/cpp.py +13 -28
  18. jarvis/jarvis_lsp/go.py +13 -28
  19. jarvis/jarvis_lsp/python.py +8 -27
  20. jarvis/jarvis_lsp/registry.py +21 -83
  21. jarvis/jarvis_lsp/rust.py +15 -30
  22. jarvis/jarvis_methodology/main.py +101 -0
  23. jarvis/jarvis_multi_agent/__init__.py +10 -51
  24. jarvis/jarvis_multi_agent/main.py +43 -0
  25. jarvis/jarvis_platform/__init__.py +1 -1
  26. jarvis/jarvis_platform/ai8.py +67 -89
  27. jarvis/jarvis_platform/base.py +14 -13
  28. jarvis/jarvis_platform/kimi.py +25 -28
  29. jarvis/jarvis_platform/ollama.py +24 -26
  30. jarvis/jarvis_platform/openai.py +15 -19
  31. jarvis/jarvis_platform/oyi.py +48 -50
  32. jarvis/jarvis_platform/registry.py +29 -44
  33. jarvis/jarvis_platform/yuanbao.py +39 -43
  34. jarvis/jarvis_platform_manager/main.py +81 -81
  35. jarvis/jarvis_platform_manager/openai_test.py +21 -21
  36. jarvis/jarvis_rag/file_processors.py +18 -18
  37. jarvis/jarvis_rag/main.py +262 -278
  38. jarvis/jarvis_smart_shell/main.py +12 -12
  39. jarvis/jarvis_tools/ask_codebase.py +85 -78
  40. jarvis/jarvis_tools/ask_user.py +8 -8
  41. jarvis/jarvis_tools/base.py +4 -4
  42. jarvis/jarvis_tools/chdir.py +9 -9
  43. jarvis/jarvis_tools/code_review.py +40 -21
  44. jarvis/jarvis_tools/create_code_agent.py +15 -15
  45. jarvis/jarvis_tools/create_sub_agent.py +0 -1
  46. jarvis/jarvis_tools/execute_python_script.py +3 -3
  47. jarvis/jarvis_tools/execute_shell.py +11 -11
  48. jarvis/jarvis_tools/execute_shell_script.py +3 -3
  49. jarvis/jarvis_tools/file_analyzer.py +116 -105
  50. jarvis/jarvis_tools/file_operation.py +22 -20
  51. jarvis/jarvis_tools/find_caller.py +105 -40
  52. jarvis/jarvis_tools/find_methodolopy.py +65 -0
  53. jarvis/jarvis_tools/find_symbol.py +123 -39
  54. jarvis/jarvis_tools/function_analyzer.py +140 -57
  55. jarvis/jarvis_tools/git_commiter.py +10 -10
  56. jarvis/jarvis_tools/lsp_get_diagnostics.py +19 -19
  57. jarvis/jarvis_tools/methodology.py +22 -67
  58. jarvis/jarvis_tools/project_analyzer.py +137 -53
  59. jarvis/jarvis_tools/rag.py +15 -20
  60. jarvis/jarvis_tools/read_code.py +25 -23
  61. jarvis/jarvis_tools/read_webpage.py +31 -31
  62. jarvis/jarvis_tools/registry.py +72 -52
  63. jarvis/jarvis_tools/search_web.py +23 -353
  64. jarvis/jarvis_tools/tool_generator.py +19 -19
  65. jarvis/jarvis_utils/config.py +36 -96
  66. jarvis/jarvis_utils/embedding.py +83 -83
  67. jarvis/jarvis_utils/git_utils.py +20 -20
  68. jarvis/jarvis_utils/globals.py +18 -6
  69. jarvis/jarvis_utils/input.py +10 -9
  70. jarvis/jarvis_utils/methodology.py +141 -140
  71. jarvis/jarvis_utils/output.py +13 -13
  72. jarvis/jarvis_utils/utils.py +23 -71
  73. {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/METADATA +6 -15
  74. jarvis_ai_assistant-0.1.138.dist-info/RECORD +85 -0
  75. {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/entry_points.txt +4 -3
  76. jarvis/jarvis_tools/lsp_find_definition.py +0 -150
  77. jarvis/jarvis_tools/lsp_find_references.py +0 -127
  78. jarvis/jarvis_tools/select_code_files.py +0 -62
  79. jarvis_ai_assistant-0.1.132.dist-info/RECORD +0 -82
  80. {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/LICENSE +0 -0
  81. {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/WHEEL +0 -0
  82. {jarvis_ai_assistant-0.1.132.dist-info → jarvis_ai_assistant-0.1.138.dist-info}/top_level.txt +0 -0
@@ -15,22 +15,22 @@ def list_platforms():
15
15
  """List all supported platforms and models"""
16
16
  registry = PlatformRegistry.get_global_platform_registry()
17
17
  platforms = registry.get_available_platforms()
18
-
18
+
19
19
  PrettyOutput.section("Supported platforms and models", OutputType.SUCCESS)
20
-
20
+
21
21
  for platform_name in platforms:
22
22
  # Create platform instance
23
23
  platform = registry.create_platform(platform_name)
24
24
  if not platform:
25
25
  continue
26
-
26
+
27
27
  # Get the list of models supported by the platform
28
28
  try:
29
29
  models = platform.get_model_list()
30
-
30
+
31
31
  # Print platform name
32
32
  PrettyOutput.section(f"{platform_name}", OutputType.SUCCESS)
33
-
33
+
34
34
  output = ""
35
35
  # Print model list
36
36
  if models:
@@ -42,65 +42,65 @@ def list_platforms():
42
42
  PrettyOutput.print(output, OutputType.SUCCESS, lang="markdown")
43
43
  else:
44
44
  PrettyOutput.print(" • 没有可用的模型信息", OutputType.WARNING)
45
-
45
+
46
46
  except Exception as e:
47
47
  PrettyOutput.print(f"获取 {platform_name} 的模型列表失败: {str(e)}", OutputType.WARNING)
48
48
 
49
49
  def chat_with_model(platform_name: str, model_name: str):
50
50
  """Chat with specified platform and model"""
51
51
  registry = PlatformRegistry.get_global_platform_registry()
52
-
52
+
53
53
  # Create platform instance
54
54
  platform = registry.create_platform(platform_name)
55
55
  if not platform:
56
56
  PrettyOutput.print(f"创建平台 {platform_name} 失败", OutputType.WARNING)
57
57
  return
58
-
58
+
59
59
  try:
60
60
  # Set model
61
61
  platform.set_model_name(model_name)
62
62
  platform.set_suppress_output(False)
63
63
  PrettyOutput.print(f"连接到 {platform_name} 平台 {model_name} 模型", OutputType.SUCCESS)
64
-
64
+
65
65
  # Start conversation loop
66
66
  while True:
67
67
  # Get user input
68
68
  user_input = get_multiline_input("")
69
-
69
+
70
70
  # Check if input is cancelled
71
71
  if user_input.strip() == "/bye":
72
72
  PrettyOutput.print("再见!", OutputType.SUCCESS)
73
73
  break
74
-
74
+
75
75
  # Check if input is empty
76
76
  if not user_input.strip():
77
77
  continue
78
-
78
+
79
79
  # Check if it is a clear session command
80
80
  if user_input.strip() == "/clear":
81
81
  try:
82
- platform.delete_chat()
82
+ platform.reset()
83
83
  platform.set_model_name(model_name) # Reinitialize session
84
84
  PrettyOutput.print("会话已清除", OutputType.SUCCESS)
85
85
  except Exception as e:
86
86
  PrettyOutput.print(f"清除会话失败: {str(e)}", OutputType.ERROR)
87
87
  continue
88
-
88
+
89
89
  try:
90
90
  # Send to model and get reply
91
91
  response = platform.chat_until_success(user_input)
92
92
  if not response:
93
93
  PrettyOutput.print("没有有效的回复", OutputType.WARNING)
94
-
94
+
95
95
  except Exception as e:
96
96
  PrettyOutput.print(f"聊天失败: {str(e)}", OutputType.ERROR)
97
-
97
+
98
98
  except Exception as e:
99
99
  PrettyOutput.print(f"初始化会话失败: {str(e)}", OutputType.ERROR)
100
100
  finally:
101
101
  # Clean up resources
102
102
  try:
103
- platform.delete_chat()
103
+ platform.reset()
104
104
  except:
105
105
  pass
106
106
 
@@ -160,18 +160,18 @@ def service_command(args):
160
160
  import json
161
161
  import os
162
162
  from datetime import datetime
163
-
163
+
164
164
  host = args.host
165
165
  port = args.port
166
166
  default_platform = args.platform
167
167
  default_model = args.model
168
-
168
+
169
169
  # Create logs directory if it doesn't exist
170
170
  logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
171
171
  os.makedirs(logs_dir, exist_ok=True)
172
-
172
+
173
173
  app = FastAPI(title="Jarvis API Server")
174
-
174
+
175
175
  # 添加 CORS 中间件
176
176
  app.add_middleware(
177
177
  CORSMiddleware,
@@ -180,27 +180,27 @@ def service_command(args):
180
180
  allow_methods=["*"], # 允许所有方法
181
181
  allow_headers=["*"], # 允许所有头
182
182
  )
183
-
183
+
184
184
  registry = PlatformRegistry.get_global_platform_registry()
185
-
185
+
186
186
  PrettyOutput.print(f"Starting Jarvis API server on {host}:{port}", OutputType.SUCCESS)
187
187
  PrettyOutput.print("This server provides an OpenAI-compatible API", OutputType.INFO)
188
-
188
+
189
189
  if default_platform and default_model:
190
190
  PrettyOutput.print(f"Default platform: {default_platform}, model: {default_model}", OutputType.INFO)
191
-
191
+
192
192
  PrettyOutput.print("Available platforms:", OutputType.INFO)
193
-
193
+
194
194
  # Print available platforms and models
195
195
  platforms = registry.get_available_platforms()
196
196
  list_platforms()
197
-
197
+
198
198
  # Platform and model cache
199
199
  platform_instances = {}
200
-
200
+
201
201
  # Chat history storage
202
202
  chat_histories = {}
203
-
203
+
204
204
  def get_platform_instance(platform_name: str, model_name: str):
205
205
  """Get or create a platform instance"""
206
206
  key = f"{platform_name}:{model_name}"
@@ -208,12 +208,12 @@ def service_command(args):
208
208
  platform = registry.create_platform(platform_name)
209
209
  if not platform:
210
210
  raise HTTPException(status_code=400, detail=f"Platform {platform_name} not found")
211
-
211
+
212
212
  platform.set_model_name(model_name)
213
213
  platform_instances[key] = platform
214
-
214
+
215
215
  return platform_instances[key]
216
-
216
+
217
217
  def log_conversation(conversation_id, messages, model, response=None):
218
218
  """Log conversation to file in plain text format."""
219
219
  timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@@ -228,14 +228,14 @@ def service_command(args):
228
228
  f.write(f"{message['role']}: {message['content']}\n")
229
229
  if response:
230
230
  f.write(f"\nResponse:\n{response}\n")
231
-
231
+
232
232
  PrettyOutput.print(f"Conversation logged to {log_file}", OutputType.INFO)
233
-
233
+
234
234
  @app.get("/v1/models")
235
235
  async def list_models():
236
236
  """List available models for the specified platform in OpenAI-compatible format"""
237
237
  model_list = []
238
-
238
+
239
239
  # Only get models for the currently set platform
240
240
  if default_platform:
241
241
  try:
@@ -253,10 +253,10 @@ def service_command(args):
253
253
  })
254
254
  except Exception as e:
255
255
  print(f"Error getting models for {default_platform}: {str(e)}")
256
-
256
+
257
257
  # Return model list
258
258
  return {"object": "list", "data": model_list}
259
-
259
+
260
260
  @app.post("/v1/chat/completions")
261
261
  @app.options("/v1/chat/completions") # 添加 OPTIONS 方法支持
262
262
  async def create_chat_completion(request: ChatCompletionRequest):
@@ -264,10 +264,10 @@ def service_command(args):
264
264
  model = request.model
265
265
  messages = request.messages
266
266
  stream = request.stream
267
-
267
+
268
268
  # Generate a conversation ID if this is a new conversation
269
269
  conversation_id = str(uuid.uuid4())
270
-
270
+
271
271
  # Extract platform and model name
272
272
  if "/" in model:
273
273
  platform_name, model_name = model.split("/", 1)
@@ -277,34 +277,34 @@ def service_command(args):
277
277
  platform_name, model_name = default_platform, default_model
278
278
  else:
279
279
  platform_name, model_name = "oyi", model # Default to OYI platform
280
-
280
+
281
281
  # Get platform instance
282
282
  platform = get_platform_instance(platform_name, model_name)
283
-
283
+
284
284
  # Convert messages to text format for the platform
285
285
  message_text = ""
286
286
  for msg in messages:
287
287
  role = msg.role
288
288
  content = msg.content
289
-
289
+
290
290
  if role == "system":
291
291
  message_text += f"System: {content}\n\n"
292
292
  elif role == "user":
293
293
  message_text += f"User: {content}\n\n"
294
294
  elif role == "assistant":
295
295
  message_text += f"Assistant: {content}\n\n"
296
-
296
+
297
297
  # Store messages in chat history
298
298
  chat_histories[conversation_id] = {
299
299
  "model": model,
300
300
  "messages": [{"role": m.role, "content": m.content} for m in messages]
301
301
  }
302
-
302
+
303
303
  # Log the conversation
304
- log_conversation(conversation_id,
305
- [{"role": m.role, "content": m.content} for m in messages],
304
+ log_conversation(conversation_id,
305
+ [{"role": m.role, "content": m.content} for m in messages],
306
306
  model)
307
-
307
+
308
308
  if stream:
309
309
  # Return streaming response
310
310
  return StreamingResponse(
@@ -315,23 +315,23 @@ def service_command(args):
315
315
  # Get chat response
316
316
  try:
317
317
  response_text = platform.chat_until_success(message_text)
318
-
318
+
319
319
  # Create response in OpenAI format
320
320
  completion_id = f"chatcmpl-{str(uuid.uuid4())}"
321
-
321
+
322
322
  # Update chat history with response
323
323
  if conversation_id in chat_histories:
324
324
  chat_histories[conversation_id]["messages"].append({
325
325
  "role": "assistant",
326
326
  "content": response_text
327
327
  })
328
-
328
+
329
329
  # Log the conversation with response
330
- log_conversation(conversation_id,
331
- chat_histories[conversation_id]["messages"],
330
+ log_conversation(conversation_id,
331
+ chat_histories[conversation_id]["messages"],
332
332
  model,
333
333
  response_text)
334
-
334
+
335
335
  return {
336
336
  "id": completion_id,
337
337
  "object": "chat.completion",
@@ -355,7 +355,7 @@ def service_command(args):
355
355
  }
356
356
  except Exception as e:
357
357
  raise HTTPException(status_code=500, detail=str(e))
358
-
358
+
359
359
  async def stream_chat_response(platform, message, model_name):
360
360
  """Stream chat response in OpenAI-compatible format"""
361
361
  import time
@@ -363,15 +363,15 @@ def service_command(args):
363
363
  import uuid
364
364
  from datetime import datetime
365
365
  import os
366
-
366
+
367
367
  completion_id = f"chatcmpl-{str(uuid.uuid4())}"
368
368
  created_time = int(time.time())
369
369
  conversation_id = str(uuid.uuid4())
370
-
370
+
371
371
  # Create logs directory if it doesn't exist
372
372
  logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "logs")
373
373
  os.makedirs(logs_dir, exist_ok=True)
374
-
374
+
375
375
  # 修改第一个yield语句的格式
376
376
  initial_data = {
377
377
  'id': completion_id,
@@ -386,14 +386,14 @@ def service_command(args):
386
386
  }
387
387
  res = json.dumps(initial_data)
388
388
  yield f"data: {res}\n\n"
389
-
389
+
390
390
  try:
391
391
  # 直接获取聊天响应,而不是尝试捕获stdout
392
392
  response = platform.chat_until_success(message)
393
-
393
+
394
394
  # 记录完整响应
395
395
  full_response = ""
396
-
396
+
397
397
  # 如果有响应,将其分块发送
398
398
  if response:
399
399
  # 分成小块以获得更好的流式体验
@@ -401,7 +401,7 @@ def service_command(args):
401
401
  for i in range(0, len(response), chunk_size):
402
402
  chunk = response[i:i+chunk_size]
403
403
  full_response += chunk
404
-
404
+
405
405
  # 创建并发送块
406
406
  chunk_data = {
407
407
  'id': completion_id,
@@ -414,9 +414,9 @@ def service_command(args):
414
414
  'finish_reason': None
415
415
  }]
416
416
  }
417
-
417
+
418
418
  yield f"data: {json.dumps(chunk_data)}\n\n"
419
-
419
+
420
420
  # 小延迟以模拟流式传输
421
421
  await asyncio.sleep(0.01)
422
422
  else:
@@ -434,7 +434,7 @@ def service_command(args):
434
434
  }
435
435
  yield f"data: {json.dumps(chunk_data)}\n\n"
436
436
  full_response = "No response from model."
437
-
437
+
438
438
  # 修改最终yield语句的格式
439
439
  final_data = {
440
440
  'id': completion_id,
@@ -448,14 +448,14 @@ def service_command(args):
448
448
  }]
449
449
  }
450
450
  yield f"data: {json.dumps(final_data)}\n\n"
451
-
451
+
452
452
  # 发送[DONE]标记
453
453
  yield "data: [DONE]\n\n"
454
-
454
+
455
455
  # 记录对话到文件
456
456
  timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
457
457
  log_file = os.path.join(logs_dir, f"stream_conversation_{conversation_id}_{timestamp}.json")
458
-
458
+
459
459
  log_data = {
460
460
  "conversation_id": conversation_id,
461
461
  "timestamp": timestamp,
@@ -463,12 +463,12 @@ def service_command(args):
463
463
  "message": message,
464
464
  "response": full_response
465
465
  }
466
-
466
+
467
467
  with open(log_file, "w", encoding="utf-8", errors="ignore") as f:
468
468
  json.dump(log_data, f, ensure_ascii=False, indent=2)
469
-
469
+
470
470
  PrettyOutput.print(f"Stream conversation logged to {log_file}", OutputType.INFO)
471
-
471
+
472
472
  except Exception as e:
473
473
  # 发送错误消息
474
474
  error_msg = f"Error: {str(e)}"
@@ -488,11 +488,11 @@ def service_command(args):
488
488
  yield f"data: {res}\n\n"
489
489
  yield f"data: {json.dumps({'error': {'message': error_msg, 'type': 'server_error'}})}\n\n"
490
490
  yield "data: [DONE]\n\n"
491
-
491
+
492
492
  # 记录错误到文件
493
493
  timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
494
494
  log_file = os.path.join(logs_dir, f"stream_error_{conversation_id}_{timestamp}.json")
495
-
495
+
496
496
  log_data = {
497
497
  "conversation_id": conversation_id,
498
498
  "timestamp": timestamp,
@@ -500,12 +500,12 @@ def service_command(args):
500
500
  "message": message,
501
501
  "error": error_msg
502
502
  }
503
-
503
+
504
504
  with open(log_file, "w", encoding="utf-8", errors="ignore") as f:
505
505
  json.dump(log_data, f, ensure_ascii=False, indent=2)
506
-
506
+
507
507
  PrettyOutput.print(f"Stream error logged to {log_file}", OutputType.ERROR)
508
-
508
+
509
509
  # Run the server
510
510
  uvicorn.run(app, host=host, port=port)
511
511
 
@@ -514,27 +514,27 @@ def main():
514
514
  import argparse
515
515
 
516
516
  init_env()
517
-
517
+
518
518
  parser = argparse.ArgumentParser(description='Jarvis AI 平台')
519
519
  subparsers = parser.add_subparsers(dest='command', help='可用子命令')
520
-
520
+
521
521
  # info subcommand
522
522
  info_parser = subparsers.add_parser('info', help='显示支持的平台和模型信息')
523
-
523
+
524
524
  # chat subcommand
525
525
  chat_parser = subparsers.add_parser('chat', help='与指定平台和模型聊天')
526
526
  chat_parser.add_argument('--platform', '-p', help='指定要使用的平台')
527
527
  chat_parser.add_argument('--model', '-m', help='指定要使用的模型')
528
-
528
+
529
529
  # service subcommand
530
530
  service_parser = subparsers.add_parser('service', help='启动OpenAI兼容的API服务')
531
531
  service_parser.add_argument('--host', default='127.0.0.1', help='服务主机地址 (默认: 127.0.0.1)')
532
532
  service_parser.add_argument('--port', type=int, default=8000, help='服务端口 (默认: 8000)')
533
533
  service_parser.add_argument('--platform', '-p', help='指定默认平台,当客户端未指定平台时使用')
534
534
  service_parser.add_argument('--model', '-m', help='指定默认模型,当客户端未指定模型时使用')
535
-
535
+
536
536
  args = parser.parse_args()
537
-
537
+
538
538
  if args.command == 'info':
539
539
  info_command(args)
540
540
  elif args.command == 'chat':
@@ -13,10 +13,10 @@ def test_chat(api_base, model, stream=False, interactive=False):
13
13
  api_key="dummy-key", # Not actually used by our service
14
14
  base_url=f"{api_base}/v1"
15
15
  )
16
-
16
+
17
17
  print(f"Testing chat with model: {model}, stream={stream}")
18
18
  print("=" * 50)
19
-
19
+
20
20
  try:
21
21
  # First, list available models
22
22
  print("Available models:")
@@ -24,35 +24,35 @@ def test_chat(api_base, model, stream=False, interactive=False):
24
24
  for m in models.data:
25
25
  print(f" - {m.id}")
26
26
  print()
27
-
27
+
28
28
  if interactive:
29
29
  # Interactive chat mode
30
30
  messages = [
31
31
  {"role": "system", "content": "You are a helpful assistant."}
32
32
  ]
33
-
33
+
34
34
  print("Interactive chat mode. Type 'exit' to quit.")
35
35
  print("=" * 50)
36
-
36
+
37
37
  while True:
38
38
  # Get user input
39
39
  user_input = input("You: ")
40
40
  if user_input.lower() in ['exit', 'quit', 'bye']:
41
41
  break
42
-
42
+
43
43
  # Add user message to history
44
44
  messages.append({"role": "user", "content": user_input})
45
-
45
+
46
46
  # Get response
47
47
  print("Assistant: ", end="", flush=True)
48
-
48
+
49
49
  if stream:
50
50
  response = client.chat.completions.create(
51
51
  model=model,
52
52
  messages=messages, # type: ignore
53
53
  stream=True
54
54
  ) # type: ignore
55
-
55
+
56
56
  # Process the streaming response
57
57
  assistant_response = ""
58
58
  for chunk in response:
@@ -68,14 +68,14 @@ def test_chat(api_base, model, stream=False, interactive=False):
68
68
  )
69
69
  assistant_response = response.choices[0].message.content
70
70
  print(assistant_response)
71
-
71
+
72
72
  # Add assistant response to history
73
73
  messages.append({"role": "assistant", "content": assistant_response}) # type: ignore
74
74
  print()
75
-
75
+
76
76
  print("=" * 50)
77
77
  print("Chat session ended.")
78
-
78
+
79
79
  else:
80
80
  # Single request mode
81
81
  print("Sending chat request...")
@@ -83,17 +83,17 @@ def test_chat(api_base, model, stream=False, interactive=False):
83
83
  {"role": "system", "content": "You are a helpful assistant."},
84
84
  {"role": "user", "content": "Hello! Tell me a short joke."}
85
85
  ]
86
-
86
+
87
87
  if stream:
88
88
  print("Response (streaming):")
89
-
89
+
90
90
  # Use the OpenAI client for streaming
91
91
  response = client.chat.completions.create(
92
92
  model=model,
93
93
  messages=messages, # type: ignore
94
94
  stream=True
95
95
  ) # type: ignore
96
-
96
+
97
97
  # Process the streaming response
98
98
  full_content = ""
99
99
  for chunk in response:
@@ -101,7 +101,7 @@ def test_chat(api_base, model, stream=False, interactive=False):
101
101
  content = chunk.choices[0].delta.content
102
102
  full_content += content
103
103
  print(content, end="", flush=True)
104
-
104
+
105
105
  print("\n")
106
106
  print(f"Full response: {full_content}")
107
107
  else:
@@ -111,16 +111,16 @@ def test_chat(api_base, model, stream=False, interactive=False):
111
111
  messages=messages # type: ignore
112
112
  )
113
113
  print(response.choices[0].message.content)
114
-
114
+
115
115
  print("=" * 50)
116
116
  print("Test completed successfully!")
117
-
117
+
118
118
  except Exception as e:
119
119
  print(f"Error: {str(e)}")
120
120
  import traceback
121
121
  traceback.print_exc()
122
122
  return 1
123
-
123
+
124
124
  return 0
125
125
 
126
126
  def main():
@@ -129,9 +129,9 @@ def main():
129
129
  parser.add_argument("--model", default="gpt-3.5-turbo", help="Model to test (default: gpt-3.5-turbo)")
130
130
  parser.add_argument("--stream", action="store_true", help="Test streaming mode")
131
131
  parser.add_argument("--interactive", "-i", action="store_true", help="Interactive chat mode")
132
-
132
+
133
133
  args = parser.parse_args()
134
-
134
+
135
135
  return test_chat(args.api_base, args.model, args.stream, args.interactive)
136
136
 
137
137
  if __name__ == "__main__":