auto-coder 0.1.256__py3-none-any.whl → 0.1.257__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic. Click here for more details.

@@ -0,0 +1,192 @@
1
+ from prompt_toolkit.shortcuts import radiolist_dialog, input_dialog
2
+ from prompt_toolkit.validation import Validator, ValidationError
3
+ from rich.console import Console
4
+ from typing import Optional, Dict, Any, List
5
+ from autocoder.common.printer import Printer
6
+ import re
7
+ from pydantic import BaseModel
8
+
9
+ from autocoder.models import process_api_key_path
10
+
11
class ProviderInfo(BaseModel):
    """Connection details and pricing for a single model provider.

    Every provider describes two models, an "R1" model and a "V3" model.
    An empty model identifier means that model is not configured.
    """

    # --- identity / connection -------------------------------------------
    name: str
    endpoint: str

    # --- provider-side model identifiers ----------------------------------
    r1_model: str
    v3_model: str

    # --- credentials ------------------------------------------------------
    api_key: str

    # --- pricing (units are not stated in this module -- TODO confirm) ----
    r1_input_price: float
    r1_output_price: float
    v3_input_price: float
    v3_output_price: float
21
+
22
+
23
# Built-in provider presets offered by ModelProviderSelector.select_provider.
# Each ``name`` must equal the key returned by the provider selection dialog
# ("volcano", "siliconflow", "deepseek"), because presets are looked up by name.
# Prices are per-model input/output rates; their unit is not stated here
# (TODO confirm against the providers' price lists).
PROVIDER_INFO_LIST = [
    ProviderInfo(
        name="volcano",
        endpoint="https://ark.cn-beijing.volces.com/api/v3",
        # Volcano has no fixed model names: both are left empty here and are
        # filled with user-supplied endpoint IDs during provider selection.
        r1_model="",
        v3_model="",
        api_key="",
        r1_input_price=2.0,
        r1_output_price=8.0,
        v3_input_price=1.0,
        v3_output_price=4.0,
    ),
    ProviderInfo(
        # All lower-case so it matches the "siliconflow" dialog key.
        name="siliconflow",
        endpoint="https://api.siliconflow.cn/v1",
        r1_model="Pro/deepseek-ai/DeepSeek-R1",
        v3_model="Pro/deepseek-ai/DeepSeek-V3",
        api_key="",
        r1_input_price=2.0,
        r1_output_price=4.0,
        v3_input_price=4.0,
        v3_output_price=16.0,
    ),
    ProviderInfo(
        name="deepseek",
        endpoint="https://api.deepseek.com/v1",
        r1_model="deepseek-reasoner",
        v3_model="deepseek-chat",
        api_key="",
        r1_input_price=4.0,
        r1_output_price=16.0,
        v3_input_price=2.0,
        v3_output_price=8.0,
    ),
]
58
+
59
class VolcanoEndpointValidator(Validator):
    """prompt_toolkit validator for a Volcano (Ark) endpoint ID.

    A valid ID is ``ep-``, then 14 ASCII digits, a dash, and 5 lower-case
    ASCII letters or digits, e.g. ``ep-20250204215011-vzbsg``.
    """

    # Compiled once for all instances. ``fullmatch`` requires the whole text
    # to match; ``re.match`` with ``$`` would also accept a trailing newline.
    # ``[0-9]`` is used instead of ``\d``, which matches any Unicode digit.
    _PATTERN = re.compile(r"ep-[0-9]{14}-[a-z0-9]{5}")

    def validate(self, document):
        """Raise ``ValidationError`` unless ``document.text`` is a valid endpoint ID.

        Args:
            document: prompt_toolkit document holding the current input text.

        Raises:
            ValidationError: if the text does not match the endpoint format;
                the cursor is placed at the end of the input.
        """
        text = document.text
        if self._PATTERN.fullmatch(text) is None:
            raise ValidationError(
                message='Invalid endpoint format. Should be like: ep-20250204215011-vzbsg',
                cursor_position=len(text),
            )
68
+
69
class ModelProviderSelector:
    """Interactive helper that picks a model provider and collects credentials.

    ``select_provider`` drives prompt_toolkit dialogs and returns a filled-in
    ``ProviderInfo``; ``to_models_json`` converts such an object into
    ``models.json``-style model entries.
    """

    def __init__(self):
        self.printer = Printer()
        self.console = Console()

    @staticmethod
    def _model_entry(
        provider_info: ProviderInfo,
        *,
        name: str,
        description: str,
        model_name: str,
        is_reasoning: bool,
        input_price: float,
        output_price: float,
    ) -> Dict[str, Any]:
        """Build one models.json entry; endpoint and API key come from *provider_info*."""
        return {
            "name": name,
            "description": description,
            "model_name": model_name,
            "model_type": "saas/openai",
            "base_url": provider_info.endpoint,
            "api_key": provider_info.api_key,
            # api_key_path is the same string as the entry name.
            "api_key_path": name,
            "is_reasoning": is_reasoning,
            "input_price": input_price,
            "output_price": output_price,
            "average_speed": 0.0,
        }

    def to_models_json(self, provider_info: ProviderInfo) -> List[Dict[str, Any]]:
        """
        Convert provider info to models.json format.

        The entries match the format of ``default_models_list`` in models.py.
        An ``r1_chat`` entry (reasoning model, for design/review) is emitted
        when ``provider_info.r1_model`` is non-empty, and a ``v3_chat`` entry
        (chat model, for coding) when ``provider_info.v3_model`` is non-empty.

        Args:
            provider_info: ProviderInfo object containing provider details

        Returns:
            List[Dict[str, Any]]: Zero, one or two model configurations, R1 first.
        """
        models: List[Dict[str, Any]] = []

        if provider_info.r1_model:
            models.append(self._model_entry(
                provider_info,
                name="r1_chat",
                description=f"{provider_info.name} R1 is for design/review",
                model_name=provider_info.r1_model,
                is_reasoning=True,
                input_price=provider_info.r1_input_price,
                output_price=provider_info.r1_output_price,
            ))

        if provider_info.v3_model:
            models.append(self._model_entry(
                provider_info,
                name="v3_chat",
                description=f"{provider_info.name} Chat is for coding",
                model_name=provider_info.v3_model,
                is_reasoning=False,
                input_price=provider_info.v3_input_price,
                output_price=provider_info.v3_output_price,
            ))

        return models

    def _ask_volcano_endpoint(self, text_key: str) -> Optional[str]:
        """Prompt for one validated Volcano endpoint ID; ``None`` if cancelled."""
        return input_dialog(
            title=self.printer.get_message_from_key("model_provider_api_key_title"),
            text=self.printer.get_message_from_key(text_key),
            validator=VolcanoEndpointValidator()
        ).run()

    def select_provider(self) -> Optional[ProviderInfo]:
        """
        Let user select a model provider and input necessary credentials.

        Volcano additionally asks for its R1 and V3 endpoint IDs; every
        provider is then asked for an API key.

        Returns:
            A new ProviderInfo holding the user's inputs, or None if any dialog
            was cancelled. The shared presets in PROVIDER_INFO_LIST are not
            modified.

        Raises:
            ValueError: if the chosen key has no matching preset.
        """
        import copy  # local import: only needed to copy the chosen preset

        result = radiolist_dialog(
            title=self.printer.get_message_from_key("model_provider_select_title"),
            text=self.printer.get_message_from_key("model_provider_select_text"),
            values=[
                ("volcano", self.printer.get_message_from_key("model_provider_volcano")),
                ("siliconflow", self.printer.get_message_from_key("model_provider_guiji")),
                ("deepseek", self.printer.get_message_from_key("model_provider_deepseek"))
            ]
        ).run()

        if result is None:
            return None

        # Case-insensitive, so a preset spelled e.g. "siliconFlow" still matches
        # the "siliconflow" dialog key instead of leaving us with None.
        wanted = result.lower()
        preset = next(
            (p for p in PROVIDER_INFO_LIST if p.name.lower() == wanted),
            None,
        )
        if preset is None:
            raise ValueError(f"Unknown model provider: {result!r}")

        # Work on a private copy so user input (API key, endpoints) never
        # leaks into the module-level presets or into later calls.
        provider_info = copy.deepcopy(preset)

        if result == "volcano":
            r1_endpoint = self._ask_volcano_endpoint("model_provider_volcano_r1_text")
            if r1_endpoint is None:
                return None
            provider_info.r1_model = r1_endpoint

            v3_endpoint = self._ask_volcano_endpoint("model_provider_volcano_v3_text")
            if v3_endpoint is None:
                return None
            provider_info.v3_model = v3_endpoint

        # Get API key for all providers
        api_key = input_dialog(
            title=self.printer.get_message_from_key("model_provider_api_key_title"),
            text=self.printer.get_message_from_key(f"model_provider_{result}_api_key_text"),
            password=True
        ).run()

        if api_key is None:
            return None

        # Pasted keys often carry stray spaces or newlines that break auth.
        provider_info.api_key = api_key.strip()

        self.printer.print_panel(
            self.printer.get_message_from_key("model_provider_selected"),
            text_options={"justify": "left"},
            panel_options={
                "title": self.printer.get_message_from_key("model_provider_success_title"),
                "border_style": "green"
            }
        )

        return provider_info
autocoder/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.256"
1
+ __version__ = "0.1.257"