orch-code 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/LICENSE +21 -0
- package/README.md +624 -0
- package/cmd/apply.go +111 -0
- package/cmd/auth.go +393 -0
- package/cmd/auth_test.go +100 -0
- package/cmd/diff.go +57 -0
- package/cmd/doctor.go +149 -0
- package/cmd/explain.go +192 -0
- package/cmd/explain_test.go +62 -0
- package/cmd/init.go +100 -0
- package/cmd/interactive.go +1372 -0
- package/cmd/interactive_input.go +45 -0
- package/cmd/interactive_input_test.go +55 -0
- package/cmd/logs.go +72 -0
- package/cmd/model.go +84 -0
- package/cmd/plan.go +149 -0
- package/cmd/provider.go +189 -0
- package/cmd/provider_model_doctor_test.go +91 -0
- package/cmd/root.go +67 -0
- package/cmd/run.go +123 -0
- package/cmd/run_engine.go +208 -0
- package/cmd/run_engine_test.go +30 -0
- package/cmd/session.go +589 -0
- package/cmd/session_helpers.go +54 -0
- package/cmd/session_integration_test.go +30 -0
- package/cmd/session_list_current_test.go +87 -0
- package/cmd/session_messages_test.go +163 -0
- package/cmd/session_runs_test.go +68 -0
- package/cmd/sprint1_integration_test.go +119 -0
- package/cmd/stats.go +173 -0
- package/cmd/stats_test.go +71 -0
- package/cmd/version.go +4 -0
- package/go.mod +45 -0
- package/go.sum +108 -0
- package/internal/agents/agent.go +31 -0
- package/internal/agents/coder.go +167 -0
- package/internal/agents/planner.go +155 -0
- package/internal/agents/reviewer.go +118 -0
- package/internal/agents/runtime.go +25 -0
- package/internal/agents/runtime_test.go +77 -0
- package/internal/auth/account.go +78 -0
- package/internal/auth/oauth.go +523 -0
- package/internal/auth/store.go +287 -0
- package/internal/confidence/policy.go +174 -0
- package/internal/confidence/policy_test.go +71 -0
- package/internal/confidence/scorer.go +253 -0
- package/internal/confidence/scorer_test.go +83 -0
- package/internal/config/config.go +331 -0
- package/internal/config/config_defaults_test.go +138 -0
- package/internal/execution/contract_builder.go +160 -0
- package/internal/execution/contract_builder_test.go +68 -0
- package/internal/execution/plan_compliance.go +161 -0
- package/internal/execution/plan_compliance_test.go +71 -0
- package/internal/execution/retry_directive.go +132 -0
- package/internal/execution/scope_guard.go +69 -0
- package/internal/logger/logger.go +120 -0
- package/internal/models/contracts_test.go +100 -0
- package/internal/models/models.go +269 -0
- package/internal/orchestrator/orchestrator.go +701 -0
- package/internal/orchestrator/orchestrator_retry_test.go +135 -0
- package/internal/orchestrator/review_engine_test.go +50 -0
- package/internal/orchestrator/state.go +42 -0
- package/internal/orchestrator/test_classifier_test.go +68 -0
- package/internal/patch/applier.go +131 -0
- package/internal/patch/applier_test.go +25 -0
- package/internal/patch/parser.go +89 -0
- package/internal/patch/patch.go +60 -0
- package/internal/patch/summary.go +30 -0
- package/internal/patch/validator.go +104 -0
- package/internal/planning/normalizer.go +416 -0
- package/internal/planning/normalizer_test.go +64 -0
- package/internal/providers/errors.go +35 -0
- package/internal/providers/openai/client.go +498 -0
- package/internal/providers/openai/client_test.go +187 -0
- package/internal/providers/provider.go +47 -0
- package/internal/providers/registry.go +32 -0
- package/internal/providers/registry_test.go +57 -0
- package/internal/providers/router.go +52 -0
- package/internal/providers/state.go +114 -0
- package/internal/providers/state_test.go +64 -0
- package/internal/repo/analyzer.go +188 -0
- package/internal/repo/context.go +83 -0
- package/internal/review/engine.go +267 -0
- package/internal/review/engine_test.go +103 -0
- package/internal/runstore/store.go +137 -0
- package/internal/runstore/store_test.go +59 -0
- package/internal/runtime/lock.go +150 -0
- package/internal/runtime/lock_test.go +57 -0
- package/internal/session/compaction.go +260 -0
- package/internal/session/compaction_test.go +36 -0
- package/internal/session/service.go +117 -0
- package/internal/session/service_test.go +113 -0
- package/internal/storage/storage.go +1498 -0
- package/internal/storage/storage_test.go +413 -0
- package/internal/testing/classifier.go +80 -0
- package/internal/testing/classifier_test.go +36 -0
- package/internal/tools/command.go +160 -0
- package/internal/tools/command_test.go +56 -0
- package/internal/tools/file.go +111 -0
- package/internal/tools/git.go +77 -0
- package/internal/tools/invalid_params_test.go +36 -0
- package/internal/tools/policy.go +98 -0
- package/internal/tools/policy_test.go +36 -0
- package/internal/tools/registry_test.go +52 -0
- package/internal/tools/result.go +30 -0
- package/internal/tools/search.go +86 -0
- package/internal/tools/tool.go +94 -0
- package/main.go +9 -0
- package/npm/orch.js +25 -0
- package/package.json +41 -0
- package/scripts/changelog.js +20 -0
- package/scripts/check-release-version.js +21 -0
- package/scripts/lib/release-utils.js +223 -0
- package/scripts/postinstall.js +157 -0
- package/scripts/release.js +52 -0
|
@@ -0,0 +1,498 @@
|
|
|
1
|
+
package openai
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"bytes"
|
|
5
|
+
"context"
|
|
6
|
+
"encoding/base64"
|
|
7
|
+
"encoding/json"
|
|
8
|
+
"fmt"
|
|
9
|
+
"io"
|
|
10
|
+
"math/rand"
|
|
11
|
+
"net/http"
|
|
12
|
+
"os"
|
|
13
|
+
"strings"
|
|
14
|
+
"time"
|
|
15
|
+
|
|
16
|
+
"github.com/furkanbeydemir/orch/internal/config"
|
|
17
|
+
"github.com/furkanbeydemir/orch/internal/providers"
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
type Client struct {
|
|
21
|
+
cfg config.OpenAIProviderConfig
|
|
22
|
+
httpClient *http.Client
|
|
23
|
+
rand *rand.Rand
|
|
24
|
+
resolveToken func(context.Context) (string, error)
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
type requester interface {
|
|
28
|
+
Do(req *http.Request) (*http.Response, error)
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
// defaultAPIBaseURL is the public OpenAI API root, used for api_key auth
// when no BaseURL is configured.
const defaultAPIBaseURL = "https://api.openai.com/v1"

// defaultCodexBaseURL is the ChatGPT backend root used in account mode when
// BaseURL is empty or still points at defaultAPIBaseURL.
const defaultCodexBaseURL = "https://chatgpt.com/backend-api"
|
|
35
|
+
|
|
36
|
+
func New(cfg config.OpenAIProviderConfig) *Client {
|
|
37
|
+
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
|
38
|
+
if timeout <= 0 {
|
|
39
|
+
timeout = 90 * time.Second
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
return &Client{
|
|
43
|
+
cfg: cfg,
|
|
44
|
+
httpClient: &http.Client{
|
|
45
|
+
Timeout: timeout,
|
|
46
|
+
},
|
|
47
|
+
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
|
48
|
+
resolveToken: func(ctx context.Context) (string, error) {
|
|
49
|
+
_ = ctx
|
|
50
|
+
return "", nil
|
|
51
|
+
},
|
|
52
|
+
}
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
// SetTokenResolver installs resolver as the fallback credential source,
// consulted when the configured environment variable is empty. Passing nil
// is a no-op that keeps the current resolver.
func (c *Client) SetTokenResolver(resolver func(context.Context) (string, error)) {
	if resolver != nil {
		c.resolveToken = resolver
	}
}
|
|
61
|
+
|
|
62
|
+
func (c *Client) Name() string {
|
|
63
|
+
return "openai"
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
// Validate checks that usable credentials are available.
//
// In account mode the check is local only: the token must be a JWT from
// which a ChatGPT account id can be extracted. In every other mode an
// authenticated GET is sent to the models endpoint. A transport failure goes
// through mapHTTPError; a status >= 300 goes through mapStatusError.
func (c *Client) Validate(ctx context.Context) error {
	token, err := c.resolveAuthToken(ctx)
	if err != nil {
		return err
	}

	mode := c.authMode()
	if mode == "account" {
		_, idErr := extractAccountID(token)
		if idErr == nil {
			return nil
		}
		return &providers.Error{Code: providers.ErrAuthError, Message: "invalid account token", Cause: idErr}
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.modelsURL(mode), nil)
	if err != nil {
		return err
	}
	if err = c.applyAuthHeaders(httpReq, mode, token); err != nil {
		return err
	}

	httpResp, err := c.httpClient.Do(httpReq)
	if err != nil {
		return mapHTTPError(err, 0, "validate")
	}
	defer httpResp.Body.Close()

	if httpResp.StatusCode < 300 {
		return nil
	}
	// Best effort: the body only enriches the error, the status decides it.
	body, _ := io.ReadAll(httpResp.Body)
	return mapStatusError(httpResp.StatusCode, string(body), "validate")
}
|
|
100
|
+
|
|
101
|
+
func (c *Client) Chat(ctx context.Context, req providers.ChatRequest) (providers.ChatResponse, error) {
|
|
102
|
+
return c.chatWithDoer(ctx, req, c.httpClient)
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
// Stream adapts Chat to the streaming interface without streaming
// incrementally. It makes one blocking Chat call. On success it emits the
// whole reply as a single "token" event, followed by a "done" event whose
// metadata carries "finish_reason". A Chat error is sent on the error
// channel instead.
//
// Both channels are closed when the worker goroutine exits. If ctx is done
// before the events are consumed, the goroutine sends ctx.Err() on the error
// channel and exits. Without this, it would block forever on a stream
// channel the caller has stopped reading.
func (c *Client) Stream(ctx context.Context, req providers.ChatRequest) (<-chan providers.StreamEvent, <-chan error) {
	stream := make(chan providers.StreamEvent, 1)
	// Buffered so the single error send below never blocks.
	errCh := make(chan error, 1)

	go func() {
		defer close(stream)
		defer close(errCh)

		resp, err := c.Chat(ctx, req)
		if err != nil {
			errCh <- err
			return
		}

		events := []providers.StreamEvent{
			{Type: "token", Text: resp.Text},
			{Type: "done", Metadata: map[string]string{"finish_reason": resp.FinishReason}},
		}
		for _, ev := range events {
			select {
			case stream <- ev:
			case <-ctx.Done():
				errCh <- ctx.Err()
				return
			}
		}
	}()

	return stream, errCh
}
|
|
125
|
+
|
|
126
|
+
// chatWithDoer sends req to the Responses endpoint through doer and retries
// failed attempts with exponential backoff.
//
// At most cfg.MaxRetries attempts are made in total (3 when the setting is
// non-positive). Model selection: req.Model if set, otherwise the configured
// model for req.Role. Reasoning effort: req.ReasoningEffort takes precedence
// over cfg.ReasoningEffort.
//
// Retry rules:
//   - Errors building the request or applying auth headers are returned
//     immediately.
//   - Transport errors and non-2xx statuses are retried only when
//     isRetryable accepts the mapped error.
//   - Unreadable or unparsable bodies are always retried.
//   - Retrying stops as soon as ctx is done, and the backoff wait is
//     interruptible. A cancelled caller therefore gets the last attempt's
//     error promptly, instead of sleeping through retries that cannot succeed.
func (c *Client) chatWithDoer(ctx context.Context, req providers.ChatRequest, doer requester) (providers.ChatResponse, error) {
	key, err := c.resolveAuthToken(ctx)
	if err != nil {
		return providers.ChatResponse{}, err
	}
	mode := c.authMode()

	model := strings.TrimSpace(req.Model)
	if model == "" {
		model = c.defaultModel(req.Role)
	}
	if model == "" {
		return providers.ChatResponse{}, &providers.Error{Code: providers.ErrModelUnavailable, Message: fmt.Sprintf("model not configured for role %s", req.Role)}
	}

	payload := map[string]any{
		"model": model,
		"input": buildInput(req.SystemPrompt, req.UserPrompt),
	}
	if effort := strings.TrimSpace(req.ReasoningEffort); effort != "" {
		payload["reasoning"] = map[string]any{"effort": effort}
	} else if effort := strings.TrimSpace(c.cfg.ReasoningEffort); effort != "" {
		payload["reasoning"] = map[string]any{"effort": effort}
	}

	bodyBytes, err := json.Marshal(payload)
	if err != nil {
		return providers.ChatResponse{}, err
	}

	attempts := c.cfg.MaxRetries
	if attempts <= 0 {
		attempts = 3
	}

	var lastErr error
	for attempt := 1; attempt <= attempts; attempt++ {
		parsed, retryable, attemptErr := c.chatAttempt(ctx, doer, mode, key, bodyBytes)
		if attemptErr == nil {
			parsed.ProviderMetadata = map[string]string{
				"provider": "openai",
				"model":    model,
			}
			return parsed, nil
		}
		lastErr = attemptErr
		if !retryable || attempt >= attempts {
			return providers.ChatResponse{}, attemptErr
		}
		if !c.waitForRetry(ctx, attempt) {
			// ctx is done: every further attempt would fail immediately.
			return providers.ChatResponse{}, attemptErr
		}
	}

	// Defensive: attempts >= 1 and every iteration returns or continues.
	if lastErr == nil {
		lastErr = &providers.Error{Code: providers.ErrTransient, Message: "provider call failed"}
	}
	return providers.ChatResponse{}, lastErr
}

// chatAttempt performs a single POST of body to the Responses endpoint and
// parses the reply. retryable reports whether a failed attempt may be
// repeated; it is meaningless when err is nil.
func (c *Client) chatAttempt(ctx context.Context, doer requester, mode, key string, body []byte) (providers.ChatResponse, bool, error) {
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.responsesURL(mode), bytes.NewReader(body))
	if err != nil {
		return providers.ChatResponse{}, false, err
	}
	if err := c.applyAuthHeaders(httpReq, mode, key); err != nil {
		return providers.ChatResponse{}, false, err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	httpResp, err := doer.Do(httpReq)
	if err != nil {
		mapped := mapHTTPError(err, 0, "chat")
		return providers.ChatResponse{}, isRetryable(mapped), mapped
	}

	data, readErr := io.ReadAll(httpResp.Body)
	// The body has been consumed (or is unreadable); a close error adds nothing.
	_ = httpResp.Body.Close()
	if readErr != nil {
		return providers.ChatResponse{}, true, &providers.Error{Code: providers.ErrInvalidResponse, Message: "failed to read provider response", Cause: readErr}
	}

	if httpResp.StatusCode >= 300 {
		mapped := mapStatusError(httpResp.StatusCode, string(data), "chat")
		return providers.ChatResponse{}, isRetryable(mapped), mapped
	}

	parsed, parseErr := parseResponse(data)
	if parseErr != nil {
		return providers.ChatResponse{}, true, &providers.Error{Code: providers.ErrInvalidResponse, Message: "failed to parse provider response", Cause: parseErr}
	}
	return parsed, false, nil
}

// waitForRetry blocks for the backoff that follows the given 1-based
// attempt: 250ms doubled per earlier attempt (exponent capped at 10),
// plus 0-199ms of jitter. It returns false, possibly before the delay
// elapses, when ctx is done.
func (c *Client) waitForRetry(ctx context.Context, attempt int) bool {
	if ctx.Err() != nil {
		return false
	}
	shift := attempt - 1
	if shift < 0 {
		shift = 0
	} else if shift > 10 {
		shift = 10
	}
	delay := time.Duration(250*(1<<shift))*time.Millisecond + time.Duration(c.rand.Intn(200))*time.Millisecond

	timer := time.NewTimer(delay)
	defer timer.Stop()
	select {
	case <-timer.C:
		return true
	case <-ctx.Done():
		return false
	}
}
|
|
227
|
+
|
|
228
|
+
func (c *Client) defaultModel(role providers.Role) string {
|
|
229
|
+
switch role {
|
|
230
|
+
case providers.RolePlanner:
|
|
231
|
+
return c.cfg.Models.Planner
|
|
232
|
+
case providers.RoleCoder:
|
|
233
|
+
return c.cfg.Models.Coder
|
|
234
|
+
case providers.RoleReviewer:
|
|
235
|
+
return c.cfg.Models.Reviewer
|
|
236
|
+
default:
|
|
237
|
+
return ""
|
|
238
|
+
}
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
// buildInput assembles the Responses API "input" list. An optional system
// message comes first and is omitted when system is blank after trimming.
// The user message always follows. Message text is sent untrimmed.
func buildInput(system, user string) []map[string]any {
	message := func(role, text string) map[string]any {
		return map[string]any{
			"role":    role,
			"content": []map[string]string{{"type": "input_text", "text": text}},
		}
	}

	parts := make([]map[string]any, 0, 2)
	if strings.TrimSpace(system) != "" {
		parts = append(parts, message("system", system))
	}
	return append(parts, message("user", user))
}
|
|
259
|
+
|
|
260
|
+
// responsesAPI mirrors the subset of a Responses API reply that
// parseResponse reads; encoding/json ignores all other fields.
type responsesAPI struct {
	// Output lists output items; only the text of each content part is kept.
	Output []struct {
		Content []struct {
			Text string `json:"text"`
		} `json:"content"`
	} `json:"output"`

	// OutputText is the top-level aggregated text, when the server sends it.
	OutputText string `json:"output_text"`

	// Usage is the token accounting reported for the call.
	Usage struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
		TotalTokens  int `json:"total_tokens"`
	} `json:"usage"`

	// Status is the response status; parseResponse surfaces it as the
	// finish reason.
	Status string `json:"status"`
}
|
|
274
|
+
|
|
275
|
+
// parseResponse decodes a Responses API body into a providers.ChatResponse.
//
// The top-level output_text field is preferred. When it is blank, the text
// parts of all output items are joined with "\n", skipping whitespace-only
// parts. A body that yields no text at all is an error. The response status
// becomes FinishReason, and token usage is copied through unchanged.
func parseResponse(data []byte) (providers.ChatResponse, error) {
	var raw responsesAPI
	if err := json.Unmarshal(data, &raw); err != nil {
		return providers.ChatResponse{}, err
	}

	text := strings.TrimSpace(raw.OutputText)
	if text == "" {
		var pieces []string
		for _, item := range raw.Output {
			for _, part := range item.Content {
				if strings.TrimSpace(part.Text) == "" {
					continue
				}
				pieces = append(pieces, part.Text)
			}
		}
		text = strings.TrimSpace(strings.Join(pieces, "\n"))
	}
	if text == "" {
		return providers.ChatResponse{}, fmt.Errorf("empty output text")
	}

	var out providers.ChatResponse
	out.Text = text
	out.FinishReason = raw.Status
	out.Usage = providers.Usage{
		InputTokens:  raw.Usage.InputTokens,
		OutputTokens: raw.Usage.OutputTokens,
		TotalTokens:  raw.Usage.TotalTokens,
	}
	return out, nil
}
|
|
311
|
+
|
|
312
|
+
func (c *Client) sleepBackoff(attempt int) {
|
|
313
|
+
base := time.Duration(250*(1<<(attempt-1))) * time.Millisecond
|
|
314
|
+
jitter := time.Duration(c.rand.Intn(200)) * time.Millisecond
|
|
315
|
+
time.Sleep(base + jitter)
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
// mapHTTPError converts a transport-level failure into a *providers.Error
// whose message is prefixed with op. err must be non-nil.
//
// The failure is classified as ErrTimeout in either of two cases:
//   - err reports Timeout() == true. *url.Error from http.Client.Do forwards
//     this from the wrapped error, which covers an expired context deadline
//     ("context deadline exceeded"). A message-only check misses that case.
//   - The error message mentions "timeout", as a fallback for error types
//     without a Timeout method.
//
// Everything else maps to ErrTransient. status >= 500 only changes the
// message wording.
func mapHTTPError(err error, status int, op string) error {
	te, reportsTimeout := err.(interface{ Timeout() bool })
	if (reportsTimeout && te.Timeout()) || strings.Contains(strings.ToLower(err.Error()), "timeout") {
		return &providers.Error{Code: providers.ErrTimeout, Message: fmt.Sprintf("%s timeout", op), Cause: err}
	}
	if status >= 500 {
		return &providers.Error{Code: providers.ErrTransient, Message: fmt.Sprintf("%s transient failure", op), Cause: err}
	}
	return &providers.Error{Code: providers.ErrTransient, Message: fmt.Sprintf("%s request failed", op), Cause: err}
}
|
|
327
|
+
|
|
328
|
+
// mapStatusError maps a non-success HTTP status to a *providers.Error whose
// message is prefixed with op. Mapping:
//   - 401/403 -> ErrAuthError
//   - 429 -> ErrRateLimited
//   - 404 -> ErrModelUnavailable
//   - >= 500 -> ErrTransient
//   - anything else -> ErrInvalidResponse
//
// The status and trimmed body are preserved as the cause.
func mapStatusError(status int, body, op string) error {
	cause := fmt.Errorf("status=%d body=%s", status, strings.TrimSpace(body))

	if status == http.StatusUnauthorized || status == http.StatusForbidden {
		return &providers.Error{Code: providers.ErrAuthError, Message: op + " unauthorized", Cause: cause}
	}
	if status == http.StatusTooManyRequests {
		return &providers.Error{Code: providers.ErrRateLimited, Message: op + " rate limited", Cause: cause}
	}
	if status == http.StatusNotFound {
		return &providers.Error{Code: providers.ErrModelUnavailable, Message: op + " model unavailable", Cause: cause}
	}
	if status >= 500 {
		return &providers.Error{Code: providers.ErrTransient, Message: op + " transient error", Cause: cause}
	}
	return &providers.Error{Code: providers.ErrInvalidResponse, Message: op + " invalid response", Cause: cause}
}
|
|
343
|
+
|
|
344
|
+
func isRetryable(err error) bool {
|
|
345
|
+
pe, ok := err.(*providers.Error)
|
|
346
|
+
if !ok {
|
|
347
|
+
return false
|
|
348
|
+
}
|
|
349
|
+
switch pe.Code {
|
|
350
|
+
case providers.ErrRateLimited, providers.ErrTimeout, providers.ErrTransient:
|
|
351
|
+
return true
|
|
352
|
+
default:
|
|
353
|
+
return false
|
|
354
|
+
}
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
// resolveAuthToken returns the credential for the current auth mode.
//
// Sources are tried in order:
//  1. The environment variable configured for the mode (APIKeyEnv for
//     "api_key", AccountTokenEnv for "account").
//  2. The token resolver, if one is set.
//
// Values are whitespace-trimmed. A resolver failure, an empty result, or an
// unknown mode yields an ErrAuthError.
func (c *Client) resolveAuthToken(ctx context.Context) (string, error) {
	// authMode already maps an empty setting to "api_key".
	mode := c.authMode()

	switch mode {
	case "api_key":
		if key := strings.TrimSpace(os.Getenv(c.cfg.APIKeyEnv)); key != "" {
			return key, nil
		}
		if c.resolveToken != nil {
			resolved, err := c.resolveToken(ctx)
			if err != nil {
				return "", &providers.Error{Code: providers.ErrAuthError, Message: "failed to resolve api key from local auth state", Cause: err}
			}
			if key := strings.TrimSpace(resolved); key != "" {
				return key, nil
			}
		}
		return "", &providers.Error{Code: providers.ErrAuthError, Message: fmt.Sprintf("missing API key in env var %s and local auth state", c.cfg.APIKeyEnv)}

	case "account":
		if token := strings.TrimSpace(os.Getenv(c.cfg.AccountTokenEnv)); token != "" {
			return token, nil
		}
		if c.resolveToken != nil {
			resolved, err := c.resolveToken(ctx)
			if err != nil {
				return "", &providers.Error{Code: providers.ErrAuthError, Message: "failed to resolve account token", Cause: err}
			}
			if token := strings.TrimSpace(resolved); token != "" {
				return token, nil
			}
		}
		return "", &providers.Error{Code: providers.ErrAuthError, Message: fmt.Sprintf("missing account token in env var %s and local auth state", c.cfg.AccountTokenEnv)}

	default:
		return "", &providers.Error{Code: providers.ErrAuthError, Message: fmt.Sprintf("unsupported auth mode: %s", mode)}
	}
}
|
|
395
|
+
|
|
396
|
+
// authMode returns the configured auth mode, trimmed and lower-cased, or
// "api_key" when none is configured.
func (c *Client) authMode() string {
	if mode := strings.ToLower(strings.TrimSpace(c.cfg.AuthMode)); mode != "" {
		return mode
	}
	return "api_key"
}
|
|
403
|
+
|
|
404
|
+
// baseURLForMode returns the configured BaseURL, trimmed and without
// trailing slashes, or the mode's default endpoint when it is unset.
//
// In account mode, a BaseURL equal (case-insensitively) to the public API
// root is also replaced by the ChatGPT backend, because account tokens are
// served there rather than by the public API.
func (c *Client) baseURLForMode(mode string) string {
	base := strings.TrimRight(strings.TrimSpace(c.cfg.BaseURL), "/")

	if mode != "account" {
		if base == "" {
			return defaultAPIBaseURL
		}
		return base
	}

	if base == "" || strings.EqualFold(base, strings.TrimRight(defaultAPIBaseURL, "/")) {
		return defaultCodexBaseURL
	}
	return base
}
|
|
418
|
+
|
|
419
|
+
func (c *Client) modelsURL(mode string) string {
|
|
420
|
+
base := c.baseURLForMode(mode)
|
|
421
|
+
if strings.HasSuffix(base, "/models") {
|
|
422
|
+
return base
|
|
423
|
+
}
|
|
424
|
+
return base + "/models"
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
// responsesURL returns the Responses endpoint for mode, appending only the
// path segments the base URL is missing.
//
// Account mode targets ".../codex/responses"; every other mode targets
// ".../responses".
func (c *Client) responsesURL(mode string) string {
	base := c.baseURLForMode(mode)

	if mode != "account" {
		if strings.HasSuffix(base, "/responses") {
			return base
		}
		return base + "/responses"
	}

	if strings.HasSuffix(base, "/codex/responses") {
		return base
	}
	if strings.HasSuffix(base, "/codex") {
		return base + "/responses"
	}
	return base + "/codex/responses"
}
|
|
444
|
+
|
|
445
|
+
// applyAuthHeaders sets the bearer Authorization header on req.
//
// In account mode it also sets ChatGPT-Account-Id (taken from the token's
// claims), OpenAI-Beta and originator. It fails with ErrAuthError when no
// account id can be extracted; the Authorization header has already been
// set by then.
func (c *Client) applyAuthHeaders(req *http.Request, mode, token string) error {
	headers := req.Header
	headers.Set("Authorization", "Bearer "+token)

	if mode == "account" {
		accountID, err := extractAccountID(token)
		if err != nil {
			return &providers.Error{Code: providers.ErrAuthError, Message: "failed to extract account id from oauth token", Cause: err}
		}
		headers.Set("ChatGPT-Account-Id", accountID)
		headers.Set("OpenAI-Beta", "responses=experimental")
		headers.Set("originator", "orch")
	}
	return nil
}
|
|
460
|
+
|
|
461
|
+
func extractAccountID(token string) (string, error) {
|
|
462
|
+
token = strings.TrimSpace(token)
|
|
463
|
+
parts := strings.Split(token, ".")
|
|
464
|
+
if len(parts) != 3 {
|
|
465
|
+
return "", fmt.Errorf("token is not a jwt")
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
469
|
+
if err != nil {
|
|
470
|
+
payload, err = base64.URLEncoding.DecodeString(parts[1])
|
|
471
|
+
if err != nil {
|
|
472
|
+
return "", fmt.Errorf("failed to decode jwt payload: %w", err)
|
|
473
|
+
}
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
claims := map[string]any{}
|
|
477
|
+
if err := json.Unmarshal(payload, &claims); err != nil {
|
|
478
|
+
return "", fmt.Errorf("failed to parse jwt payload: %w", err)
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
if id, ok := claims["chatgpt_account_id"].(string); ok && strings.TrimSpace(id) != "" {
|
|
482
|
+
return strings.TrimSpace(id), nil
|
|
483
|
+
}
|
|
484
|
+
if nested, ok := claims["https://api.openai.com/auth"].(map[string]any); ok {
|
|
485
|
+
if id, ok := nested["chatgpt_account_id"].(string); ok && strings.TrimSpace(id) != "" {
|
|
486
|
+
return strings.TrimSpace(id), nil
|
|
487
|
+
}
|
|
488
|
+
}
|
|
489
|
+
if organizations, ok := claims["organizations"].([]any); ok && len(organizations) > 0 {
|
|
490
|
+
if org, ok := organizations[0].(map[string]any); ok {
|
|
491
|
+
if id, ok := org["id"].(string); ok && strings.TrimSpace(id) != "" {
|
|
492
|
+
return strings.TrimSpace(id), nil
|
|
493
|
+
}
|
|
494
|
+
}
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
return "", fmt.Errorf("chatgpt_account_id claim not found")
|
|
498
|
+
}
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
package openai
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"encoding/base64"
|
|
6
|
+
"fmt"
|
|
7
|
+
"io"
|
|
8
|
+
"net/http"
|
|
9
|
+
"os"
|
|
10
|
+
"strings"
|
|
11
|
+
"testing"
|
|
12
|
+
|
|
13
|
+
"github.com/furkanbeydemir/orch/internal/config"
|
|
14
|
+
"github.com/furkanbeydemir/orch/internal/providers"
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
type sequenceDoer struct {
|
|
18
|
+
responses []*http.Response
|
|
19
|
+
index int
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
type inspectDoer struct {
|
|
23
|
+
fn func(req *http.Request) (*http.Response, error)
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
func (d *inspectDoer) Do(req *http.Request) (*http.Response, error) {
|
|
27
|
+
if d.fn == nil {
|
|
28
|
+
return nil, fmt.Errorf("no inspect fn configured")
|
|
29
|
+
}
|
|
30
|
+
return d.fn(req)
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
func (d *sequenceDoer) Do(req *http.Request) (*http.Response, error) {
|
|
34
|
+
if d.index >= len(d.responses) {
|
|
35
|
+
return nil, fmt.Errorf("no response configured")
|
|
36
|
+
}
|
|
37
|
+
resp := d.responses[d.index]
|
|
38
|
+
d.index++
|
|
39
|
+
return resp, nil
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
// TestChatRetriesOnRateLimit checks that a 429 response is retried and that
// the following 200 response becomes the chat result.
func TestChatRetriesOnRateLimit(t *testing.T) {
	t.Setenv("OPENAI_API_KEY", "test-key")

	cfg := config.OpenAIProviderConfig{
		APIKeyEnv:      "OPENAI_API_KEY",
		BaseURL:        "https://example.test/v1",
		TimeoutSeconds: 5,
		MaxRetries:     2,
		Models:         config.ProviderRoleModels{Coder: "gpt-5.3-codex"},
	}
	client := New(cfg)

	rateLimited := response(http.StatusTooManyRequests, `{"error":"rate"}`)
	success := response(http.StatusOK, `{"output_text":"done","status":"completed","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`)
	doer := &sequenceDoer{responses: []*http.Response{rateLimited, success}}

	out, err := client.chatWithDoer(context.Background(), providers.ChatRequest{Role: providers.RoleCoder}, doer)
	switch {
	case err != nil:
		t.Fatalf("chat should succeed after retry: %v", err)
	case strings.TrimSpace(out.Text) != "done":
		t.Fatalf("unexpected text: %q", out.Text)
	case doer.index != 2:
		t.Fatalf("expected 2 attempts, got %d", doer.index)
	}
}
|
|
71
|
+
|
|
72
|
+
// TestValidateMissingKey checks that Validate fails with ErrAuthError when
// no API key is available from the environment or the default resolver.
func TestValidateMissingKey(t *testing.T) {
	// t.Setenv registers a cleanup that restores the caller's original
	// value. Unsetting afterwards makes the variable truly absent without
	// leaking that change into later tests in the package.
	t.Setenv("OPENAI_API_KEY", "")
	if err := os.Unsetenv("OPENAI_API_KEY"); err != nil {
		t.Fatalf("unset OPENAI_API_KEY: %v", err)
	}

	client := New(config.OpenAIProviderConfig{APIKeyEnv: "OPENAI_API_KEY", BaseURL: "https://example.test/v1"})
	err := client.Validate(context.Background())
	if err == nil {
		t.Fatalf("expected validate error")
	}
	perr, ok := err.(*providers.Error)
	if !ok {
		t.Fatalf("expected provider error type")
	}
	if perr.Code != providers.ErrAuthError {
		t.Fatalf("unexpected error code: %s", perr.Code)
	}
}
|
|
87
|
+
|
|
88
|
+
// TestMapStatusError checks the status-to-code mapping for every branch of
// mapStatusError, not only 401.
func TestMapStatusError(t *testing.T) {
	mapped := func(status int) *providers.Error {
		t.Helper()
		perr, ok := mapStatusError(status, "bad", "chat").(*providers.Error)
		if !ok {
			t.Fatalf("status %d: expected provider error", status)
		}
		return perr
	}

	if got := mapped(http.StatusUnauthorized).Code; got != providers.ErrAuthError {
		t.Fatalf("401: unexpected code: %s", got)
	}
	if got := mapped(http.StatusForbidden).Code; got != providers.ErrAuthError {
		t.Fatalf("403: unexpected code: %s", got)
	}
	if got := mapped(http.StatusTooManyRequests).Code; got != providers.ErrRateLimited {
		t.Fatalf("429: unexpected code: %s", got)
	}
	if got := mapped(http.StatusNotFound).Code; got != providers.ErrModelUnavailable {
		t.Fatalf("404: unexpected code: %s", got)
	}
	if got := mapped(http.StatusServiceUnavailable).Code; got != providers.ErrTransient {
		t.Fatalf("503: unexpected code: %s", got)
	}
	if got := mapped(http.StatusBadRequest).Code; got != providers.ErrInvalidResponse {
		t.Fatalf("400: unexpected code: %s", got)
	}
}
|
|
98
|
+
|
|
99
|
+
// TestResolveAuthTokenAccountModeWithResolver checks that, in account mode
// with an empty token env var, the installed resolver supplies the token.
func TestResolveAuthTokenAccountModeWithResolver(t *testing.T) {
	// The env var takes precedence over the resolver. Blank it so a token
	// in the developer's environment cannot mask the code under test; the
	// original value is restored after the test.
	t.Setenv("OPENAI_ACCOUNT_TOKEN", "")

	client := New(config.OpenAIProviderConfig{
		AuthMode:        "account",
		AccountTokenEnv: "OPENAI_ACCOUNT_TOKEN",
	})
	client.SetTokenResolver(func(ctx context.Context) (string, error) {
		return "account-token", nil
	})

	token, err := client.resolveAuthToken(context.Background())
	if err != nil {
		t.Fatalf("resolve token: %v", err)
	}
	if token != "account-token" {
		t.Fatalf("unexpected token: %s", token)
	}
}
|
|
116
|
+
|
|
117
|
+
// TestChatAccountModeUsesCodexEndpointAndAccountHeader checks two things in
// account mode. First, a public-API BaseURL is redirected to the ChatGPT
// codex endpoint. Second, the account id from the token is sent in the
// ChatGPT-Account-Id header.
func TestChatAccountModeUsesCodexEndpointAndAccountHeader(t *testing.T) {
	// Keep the resolver's token authoritative regardless of the developer's
	// environment; the original value is restored after the test.
	t.Setenv("OPENAI_ACCOUNT_TOKEN", "")

	client := New(config.OpenAIProviderConfig{
		AuthMode:        "account",
		BaseURL:         "https://api.openai.com/v1",
		AccountTokenEnv: "OPENAI_ACCOUNT_TOKEN",
		Models: config.ProviderRoleModels{
			Coder: "gpt-5.3-codex",
		},
	})
	client.SetTokenResolver(func(ctx context.Context) (string, error) {
		return testAccountToken("acc-123"), nil
	})

	doer := &inspectDoer{fn: func(req *http.Request) (*http.Response, error) {
		if got := req.URL.String(); got != "https://chatgpt.com/backend-api/codex/responses" {
			return nil, fmt.Errorf("unexpected request url: %s", got)
		}
		if got := req.Header.Get("ChatGPT-Account-Id"); got != "acc-123" {
			return nil, fmt.Errorf("unexpected account header: %s", got)
		}
		if got := req.Header.Get("Authorization"); !strings.HasPrefix(got, "Bearer ") {
			return nil, fmt.Errorf("missing auth header")
		}
		return response(http.StatusOK, `{"output_text":"done","status":"completed","usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`), nil
	}}

	out, err := client.chatWithDoer(context.Background(), providers.ChatRequest{Role: providers.RoleCoder}, doer)
	if err != nil {
		t.Fatalf("chat error: %v", err)
	}
	if strings.TrimSpace(out.Text) != "done" {
		t.Fatalf("unexpected text: %q", out.Text)
	}
}
|
|
151
|
+
|
|
152
|
+
// TestValidateAccountModeRejectsNonJWTToken checks that an account token
// that is not a JWT fails validation with ErrAuthError.
func TestValidateAccountModeRejectsNonJWTToken(t *testing.T) {
	// A real JWT in the environment would win over the resolver and make
	// Validate succeed. Blank it for the duration of the test.
	t.Setenv("OPENAI_ACCOUNT_TOKEN", "")

	client := New(config.OpenAIProviderConfig{
		AuthMode:        "account",
		AccountTokenEnv: "OPENAI_ACCOUNT_TOKEN",
	})
	client.SetTokenResolver(func(ctx context.Context) (string, error) {
		return "not-a-jwt", nil
	})

	err := client.Validate(context.Background())
	if err == nil {
		t.Fatalf("expected validate error")
	}
	perr, ok := err.(*providers.Error)
	if !ok {
		t.Fatalf("expected provider error type")
	}
	if perr.Code != providers.ErrAuthError {
		t.Fatalf("unexpected error code: %s", perr.Code)
	}
}
|
|
173
|
+
|
|
174
|
+
// testAccountToken builds an unsigned JWT-shaped token. Its payload nests
// accountID under the "https://api.openai.com/auth" claim. accountID is
// spliced into the JSON verbatim, so it must not contain quotes.
func testAccountToken(accountID string) string {
	encode := base64.RawURLEncoding.EncodeToString
	claims := fmt.Sprintf(`{"https://api.openai.com/auth":{"chatgpt_account_id":"%s"}}`, accountID)
	segments := [3]string{
		encode([]byte(`{"alg":"none"}`)),
		encode([]byte(claims)),
		"sig",
	}
	return segments[0] + "." + segments[1] + "." + segments[2]
}
|
|
180
|
+
|
|
181
|
+
// response fabricates an *http.Response with the given status code, an
// empty (non-nil) header map, and body as its readable payload.
func response(status int, body string) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = status
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(body))
	return resp
}
|