auto-coder 0.1.256__py3-none-any.whl → 0.1.258__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic. Click here for more details.

@@ -0,0 +1,214 @@
1
+ from prompt_toolkit.shortcuts import radiolist_dialog, input_dialog
2
+ from prompt_toolkit.validation import Validator, ValidationError
3
+ from prompt_toolkit.styles import Style
4
+ from rich.console import Console
5
+ from typing import Optional, Dict, Any, List
6
+ from autocoder.common.printer import Printer
7
+ import re
8
+ from pydantic import BaseModel
9
+
10
+ from autocoder.models import process_api_key_path
11
+
12
class ProviderInfo(BaseModel):
    """Endpoint, model names, API key and prices for one model provider.

    Each entry describes two models: an "R1" model (registered as a reasoning
    model for design/review) and a "V3" model (registered for coding). An
    empty model name means "not configured"; such a model is skipped by
    ``ModelProviderSelector.to_models_json``.
    """

    name: str  # provider id; compared with the value chosen in the selection dialog
    endpoint: str  # base URL shared by both models (emitted as ``base_url``)
    r1_model: str  # R1 model name; for "volcano" this is the user-entered endpoint ID
    v3_model: str  # V3 model name; for "volcano" this is the user-entered endpoint ID
    api_key: str  # API key entered by the user; empty in the built-in presets
    # Prices copied verbatim into the generated model configs. The unit is not
    # shown here -- presumably per million tokens; TODO confirm against models.py.
    r1_input_price: float
    r1_output_price: float
    v3_input_price: float
    v3_output_price: float
22
+
23
+
24
# Built-in provider presets offered by ModelProviderSelector.select_provider.
# ``api_key`` is always empty here; the user supplies it interactively.
PROVIDER_INFO_LIST = [
    # Volcano Engine Ark (ark.cn-beijing.volces.com): model names are left
    # empty because the user must enter their own endpoint IDs (ep-...) for
    # the R1 and V3 models.
    ProviderInfo(
        name="volcano",
        endpoint="https://ark.cn-beijing.volces.com/api/v3",
        r1_model="",
        v3_model="",
        api_key="",
        r1_input_price=2.0,
        r1_output_price=8.0,
        v3_input_price=1.0,
        v3_output_price=4.0,
    ),
    # SiliconFlow-hosted DeepSeek models.
    # NOTE(review): here R1 is priced below V3, the reverse of the other two
    # presets -- confirm the r1_*/v3_* prices are not swapped.
    ProviderInfo(
        name="siliconflow",
        endpoint="https://api.siliconflow.cn/v1",
        r1_model="Pro/deepseek-ai/DeepSeek-R1",
        v3_model="Pro/deepseek-ai/DeepSeek-V3",
        api_key="",
        r1_input_price=2.0,
        r1_output_price=4.0,
        v3_input_price=4.0,
        v3_output_price=16.0,
    ),
    # DeepSeek's own API (api.deepseek.com).
    ProviderInfo(
        name="deepseek",
        endpoint="https://api.deepseek.com/v1",
        r1_model="deepseek-reasoner",
        v3_model="deepseek-chat",
        api_key="",
        r1_input_price=4.0,
        r1_output_price=16.0,
        v3_input_price=2.0,
        v3_output_price=8.0,
    ),
]
59
+
60
# Dark theme passed as ``style=`` to every prompt_toolkit dialog in this module.
# NOTE(review): keep the rule order as-is; prompt_toolkit may let later rules
# take precedence over earlier ones, so reordering could change the look.
dialog_style = Style.from_dict({
    'dialog': 'bg:#2b2b2b',
    'dialog frame.label': 'bg:#2b2b2b #ffffff',
    'dialog.body': 'bg:#2b2b2b #ffffff',
    'dialog shadow': 'bg:#1f1f1f',
    'button': 'bg:#2b2b2b #808080',
    'button.focused': 'bg:#1e1e1e #ffd700 bold',
    'checkbox': '#e6e6e6',
    'checkbox-selected': '#0078d4',
    'radio-selected': 'bg:#0078d4 #ffffff',
    'dialog frame.border': '#0078d4',
    'radio': 'bg:#2b2b2b #808080',
    'radio.focused': 'bg:#1e1e1e #ffd700 bold'
})
74
+
75
class VolcanoEndpointValidator(Validator):
    """Validate that the input is a Volcano Engine endpoint ID.

    Accepted form: ``ep-`` + 14 ASCII digits + ``-`` + 5 lowercase ASCII
    letters/digits, e.g. ``ep-20250204215011-vzbsg``.
    """

    # Compiled once per class instead of on every validate() call. ``[0-9]``
    # is used rather than ``\d``, which would also accept non-ASCII digits.
    _PATTERN = re.compile(r'ep-[0-9]{14}-[a-z0-9]{5}')

    def validate(self, document):
        """Raise ``ValidationError`` unless ``document.text`` is a full endpoint ID.

        ``fullmatch`` is used because ``re.match`` with a trailing ``$`` would
        also accept a value that ends in a newline.
        """
        text = document.text
        if not self._PATTERN.fullmatch(text):
            raise ValidationError(
                message='Invalid endpoint format. Should be like: ep-20250204215011-vzbsg',
                cursor_position=len(text)
            )
84
+
85
class ModelProviderSelector:
    """Interactive picker for a model provider and its credentials.

    Shows prompt_toolkit dialogs to choose one of ``PROVIDER_INFO_LIST``,
    collects the provider-specific inputs (endpoint IDs for volcano, an API
    key for every provider), and converts the result into models.json-style
    model configurations.
    """

    def __init__(self):
        self.printer = Printer()
        self.console = Console()

    @staticmethod
    def _build_model_config(
        provider_info: "ProviderInfo",
        *,
        name: str,
        description: str,
        model_name: str,
        is_reasoning: bool,
        input_price: float,
        output_price: float,
    ) -> Dict[str, Any]:
        """Build one models.json entry; key order matches models.py's layout."""
        return {
            "name": name,
            "description": description,
            "model_name": model_name,
            "model_type": "saas/openai",
            "base_url": provider_info.endpoint,
            "api_key": provider_info.api_key,
            # api_key_path is the same string as the model name for both entries.
            "api_key_path": name,
            "is_reasoning": is_reasoning,
            "input_price": input_price,
            "output_price": output_price,
            "average_speed": 0.0,
        }

    def to_models_json(self, provider_info: "ProviderInfo") -> List[Dict[str, Any]]:
        """
        Convert provider info to models.json format.
        Returns a list of model configurations matching the format in models.py default_models_list.

        Args:
            provider_info: ProviderInfo object containing provider details

        Returns:
            List[Dict[str, Any]]: Up to two configurations, "r1_chat" then
            "v3_chat". A model whose name is empty is omitted.
        """
        models: List[Dict[str, Any]] = []

        # R1 model (reasoning; used for design/review).
        if provider_info.r1_model:
            models.append(self._build_model_config(
                provider_info,
                name="r1_chat",
                description=f"{provider_info.name} R1 is for design/review",
                model_name=provider_info.r1_model,
                is_reasoning=True,
                input_price=provider_info.r1_input_price,
                output_price=provider_info.r1_output_price,
            ))

        # V3 model (used for coding).
        if provider_info.v3_model:
            models.append(self._build_model_config(
                provider_info,
                name="v3_chat",
                description=f"{provider_info.name} Chat is for coding",
                model_name=provider_info.v3_model,
                is_reasoning=False,
                input_price=provider_info.v3_input_price,
                output_price=provider_info.v3_output_price,
            ))

        return models

    def _ask_volcano_endpoint(self, text_key: str) -> Optional[str]:
        """Prompt for a Volcano endpoint ID; return None if the dialog is cancelled."""
        return input_dialog(
            title=self.printer.get_message_from_key("model_provider_api_key_title"),
            text=self.printer.get_message_from_key(text_key),
            validator=VolcanoEndpointValidator(),
            style=dialog_style,
        ).run()

    def select_provider(self) -> Optional["ProviderInfo"]:
        """
        Let user select a model provider and input necessary credentials.

        Returns a new ProviderInfo (a copy of the chosen preset filled in with
        the user's input), or None if any dialog is cancelled. The shared
        presets in PROVIDER_INFO_LIST are never modified.
        """
        import copy

        result = radiolist_dialog(
            title=self.printer.get_message_from_key("model_provider_select_title"),
            text=self.printer.get_message_from_key("model_provider_select_text"),
            values=[
                ("volcano", self.printer.get_message_from_key("model_provider_volcano")),
                ("siliconflow", self.printer.get_message_from_key("model_provider_siliconflow")),
                ("deepseek", self.printer.get_message_from_key("model_provider_deepseek"))
            ],
            style=dialog_style
        ).run()

        if result is None:
            return None

        preset = next((p for p in PROVIDER_INFO_LIST if p.name == result), None)
        if preset is None:
            # Dialog value without a matching preset: nothing to configure.
            return None
        # Work on a copy so entered endpoints/API keys do not leak into the
        # module-level presets (or into later calls, even after a cancel).
        provider_info = copy.deepcopy(preset)

        if result == "volcano":
            r1_endpoint = self._ask_volcano_endpoint("model_provider_volcano_r1_text")
            if r1_endpoint is None:
                return None
            provider_info.r1_model = r1_endpoint

            v3_endpoint = self._ask_volcano_endpoint("model_provider_volcano_v3_text")
            if v3_endpoint is None:
                return None
            provider_info.v3_model = v3_endpoint

        # Every provider needs an API key.
        api_key = input_dialog(
            title=self.printer.get_message_from_key("model_provider_api_key_title"),
            text=self.printer.get_message_from_key(f"model_provider_{result}_api_key_text"),
            password=True,
            style=dialog_style
        ).run()

        if api_key is None:
            return None

        provider_info.api_key = api_key

        self.printer.print_panel(
            self.printer.get_message_from_key("model_provider_selected"),
            text_options={"justify": "left"},
            panel_options={
                "title": self.printer.get_message_from_key("model_provider_success_title"),
                "border_style": "green"
            }
        )

        return provider_info
autocoder/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.256"
1
# Version string of the autocoder package.
__version__ = "0.1.258"