orch-code 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +3 -0
- package/README.md +6 -0
- package/cmd/auth.go +89 -10
- package/cmd/auth_test.go +57 -4
- package/cmd/doctor.go +13 -5
- package/cmd/interactive.go +14 -2
- package/cmd/version.go +1 -1
- package/internal/auth/account.go +27 -2
- package/internal/auth/account_session.go +154 -0
- package/internal/auth/account_session_test.go +71 -0
- package/internal/auth/store.go +477 -116
- package/internal/auth/store_test.go +73 -0
- package/internal/orchestrator/orchestrator.go +14 -2
- package/internal/providers/openai/client.go +83 -5
- package/internal/providers/openai/client_test.go +52 -0
- package/internal/providers/state_test.go +42 -0
- package/package.json +1 -1
package/CHANGELOG.md
CHANGED
package/README.md
CHANGED
|
@@ -327,8 +327,14 @@ Or account mode (OAuth):
|
|
|
327
327
|
|
|
328
328
|
```bash
|
|
329
329
|
./orch auth login openai --method account --flow auto
|
|
330
|
+
./orch auth login openai --method account --flow auto # add another account
|
|
331
|
+
./orch auth list
|
|
332
|
+
./orch auth use <credential-id>
|
|
333
|
+
./orch auth remove <credential-id>
|
|
330
334
|
```
|
|
331
335
|
|
|
336
|
+
When multiple OpenAI OAuth accounts are stored, Orch keeps one of them active and can automatically fail over to the next stored account when the active one is rate-limited or rejected.
|
|
337
|
+
|
|
332
338
|
Validate runtime readiness:
|
|
333
339
|
|
|
334
340
|
```bash
|
package/cmd/auth.go
CHANGED
|
@@ -4,7 +4,6 @@ import (
|
|
|
4
4
|
"bufio"
|
|
5
5
|
"fmt"
|
|
6
6
|
"os"
|
|
7
|
-
"sort"
|
|
8
7
|
"strings"
|
|
9
8
|
"time"
|
|
10
9
|
|
|
@@ -56,6 +55,21 @@ var authLogoutCmd = &cobra.Command{
|
|
|
56
55
|
RunE: runAuthLogout,
|
|
57
56
|
}
|
|
58
57
|
|
|
58
|
+
var authUseCmd = &cobra.Command{
|
|
59
|
+
Use: "use <credential-id>",
|
|
60
|
+
Short: "Set the active stored credential",
|
|
61
|
+
Args: cobra.ExactArgs(1),
|
|
62
|
+
RunE: runAuthUse,
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
var authRemoveCmd = &cobra.Command{
|
|
66
|
+
Use: "remove <credential-id>",
|
|
67
|
+
Aliases: []string{"rm"},
|
|
68
|
+
Short: "Remove one stored credential",
|
|
69
|
+
Args: cobra.ExactArgs(1),
|
|
70
|
+
RunE: runAuthRemove,
|
|
71
|
+
}
|
|
72
|
+
|
|
59
73
|
var authOpenAICmd = &cobra.Command{
|
|
60
74
|
Use: "openai",
|
|
61
75
|
Hidden: true,
|
|
@@ -71,10 +85,14 @@ func init() {
|
|
|
71
85
|
authLoginCmd.Flags().StringVar(&authAPIKeyFlag, "api-key", "", "API key")
|
|
72
86
|
|
|
73
87
|
authLogoutCmd.Flags().StringVarP(&authProviderFlag, "provider", "p", "openai", "Provider id")
|
|
88
|
+
authUseCmd.Flags().StringVarP(&authProviderFlag, "provider", "p", "openai", "Provider id")
|
|
89
|
+
authRemoveCmd.Flags().StringVarP(&authProviderFlag, "provider", "p", "openai", "Provider id")
|
|
74
90
|
|
|
75
91
|
authCmd.AddCommand(authLoginCmd)
|
|
76
92
|
authCmd.AddCommand(authStatusCmd)
|
|
77
93
|
authCmd.AddCommand(authListCmd)
|
|
94
|
+
authCmd.AddCommand(authUseCmd)
|
|
95
|
+
authCmd.AddCommand(authRemoveCmd)
|
|
78
96
|
authCmd.AddCommand(authLogoutCmd)
|
|
79
97
|
authOpenAICmd.AddCommand(newAuthCompatLoginCmd())
|
|
80
98
|
authOpenAICmd.AddCommand(newAuthCompatLogoutCmd())
|
|
@@ -173,6 +191,9 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
|
|
173
191
|
|
|
174
192
|
fmt.Println("Credential saved to .orch/auth.json (0600).")
|
|
175
193
|
fmt.Printf("Auth mode set to api_key. Env %s is still supported with higher priority.\n", cfg.Provider.OpenAI.APIKeyEnv)
|
|
194
|
+
if active, activeErr := auth.Get(cwd, provider); activeErr == nil && active != nil {
|
|
195
|
+
fmt.Printf("Active credential id: %s\n", active.ID)
|
|
196
|
+
}
|
|
176
197
|
return nil
|
|
177
198
|
}
|
|
178
199
|
|
|
@@ -213,6 +234,9 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
|
|
213
234
|
|
|
214
235
|
fmt.Println("Credential saved to .orch/auth.json (0600).")
|
|
215
236
|
fmt.Printf("Auth mode set to account. You can also use %s.\n", cfg.Provider.OpenAI.AccountTokenEnv)
|
|
237
|
+
if active, activeErr := auth.Get(cwd, provider); activeErr == nil && active != nil {
|
|
238
|
+
fmt.Printf("Active credential id: %s\n", active.ID)
|
|
239
|
+
}
|
|
216
240
|
if !result.ExpiresAt.IsZero() {
|
|
217
241
|
fmt.Printf("Token expires at: %s\n", result.ExpiresAt.UTC().Format(time.RFC3339))
|
|
218
242
|
}
|
|
@@ -234,6 +258,10 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
|
|
234
258
|
if err != nil {
|
|
235
259
|
return err
|
|
236
260
|
}
|
|
261
|
+
credentials, activeID, err := auth.List(cwd, "openai")
|
|
262
|
+
if err != nil {
|
|
263
|
+
return err
|
|
264
|
+
}
|
|
237
265
|
|
|
238
266
|
fmt.Println("Auth Status")
|
|
239
267
|
fmt.Println("-----------")
|
|
@@ -251,6 +279,10 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
|
|
251
279
|
fmt.Printf("stored_api_key: %t\n", storedAPIKey)
|
|
252
280
|
fmt.Printf("stored_account_token: %t\n", storedAccount)
|
|
253
281
|
fmt.Printf("stored_account_refresh: %t\n", storedRefresh)
|
|
282
|
+
fmt.Printf("stored_credentials: %d\n", len(credentials))
|
|
283
|
+
if activeID != "" {
|
|
284
|
+
fmt.Printf("active_credential_id: %s\n", activeID)
|
|
285
|
+
}
|
|
254
286
|
if cred != nil && cred.Type == "oauth" && !cred.ExpiresAt.IsZero() {
|
|
255
287
|
fmt.Printf("account_expires_at: %s\n", cred.ExpiresAt.UTC().Format(time.RFC3339))
|
|
256
288
|
}
|
|
@@ -270,29 +302,76 @@ func runAuthList(cmd *cobra.Command, args []string) error {
|
|
|
270
302
|
return fmt.Errorf("failed to get working directory: %w", err)
|
|
271
303
|
}
|
|
272
304
|
|
|
273
|
-
|
|
305
|
+
provider := resolveProviderArg(args)
|
|
306
|
+
if provider != "openai" {
|
|
307
|
+
return fmt.Errorf("unsupported provider: %s (supported: openai)", provider)
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
credentials, activeID, err := auth.List(cwd, provider)
|
|
274
311
|
if err != nil {
|
|
275
312
|
return err
|
|
276
313
|
}
|
|
277
314
|
|
|
278
315
|
fmt.Println("Stored Credentials")
|
|
279
316
|
fmt.Println("------------------")
|
|
280
|
-
if len(
|
|
317
|
+
if len(credentials) == 0 {
|
|
281
318
|
fmt.Println("No stored credentials found.")
|
|
282
319
|
return nil
|
|
283
320
|
}
|
|
284
321
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
322
|
+
for _, cred := range credentials {
|
|
323
|
+
marker := " "
|
|
324
|
+
if cred.ID == activeID {
|
|
325
|
+
marker = "*"
|
|
326
|
+
}
|
|
327
|
+
line := fmt.Sprintf("%s %s (%s)", marker, cred.ID, cred.Type)
|
|
328
|
+
if cred.Email != "" {
|
|
329
|
+
line += " " + cred.Email
|
|
330
|
+
}
|
|
331
|
+
if cred.AccountID != "" {
|
|
332
|
+
line += " account=" + cred.AccountID
|
|
333
|
+
}
|
|
334
|
+
fmt.Println(line)
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
return nil
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
// runAuthUse makes the credential named by args[0] (surrounding whitespace
// trimmed) the active stored credential for the selected provider. Only
// "openai" is accepted. resolveProviderArg(nil) presumably falls back to the
// --provider flag registered in init; confirm against its definition.
func runAuthUse(cmd *cobra.Command, args []string) error {
	workDir, wdErr := os.Getwd()
	if wdErr != nil {
		return fmt.Errorf("failed to get working directory: %w", wdErr)
	}

	provider := resolveProviderArg(nil)
	if provider != "openai" {
		return fmt.Errorf("unsupported provider: %s (supported: openai)", provider)
	}

	id := strings.TrimSpace(args[0])
	if setErr := auth.SetActive(workDir, provider, id); setErr != nil {
		return setErr
	}

	fmt.Printf("Active credential set to %s for %s.\n", id, provider)
	return nil
}
|
|
358
|
+
|
|
359
|
+
func runAuthRemove(cmd *cobra.Command, args []string) error {
|
|
360
|
+
cwd, err := os.Getwd()
|
|
361
|
+
if err != nil {
|
|
362
|
+
return fmt.Errorf("failed to get working directory: %w", err)
|
|
288
363
|
}
|
|
289
|
-
sort.Strings(providers)
|
|
290
364
|
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
fmt.
|
|
365
|
+
provider := resolveProviderArg(nil)
|
|
366
|
+
if provider != "openai" {
|
|
367
|
+
return fmt.Errorf("unsupported provider: %s (supported: openai)", provider)
|
|
368
|
+
}
|
|
369
|
+
credentialID := strings.TrimSpace(args[0])
|
|
370
|
+
if err := auth.RemoveCredential(cwd, provider, credentialID); err != nil {
|
|
371
|
+
return err
|
|
294
372
|
}
|
|
295
373
|
|
|
374
|
+
fmt.Printf("Stored credential %s removed for %s.\n", credentialID, provider)
|
|
296
375
|
return nil
|
|
297
376
|
}
|
|
298
377
|
|
package/cmd/auth_test.go
CHANGED
|
@@ -19,15 +19,27 @@ func TestAuthLoginAccountAndLogout(t *testing.T) {
|
|
|
19
19
|
t.Fatalf("save config: %v", err)
|
|
20
20
|
}
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
return auth.OAuthResult{
|
|
22
|
+
results := []auth.OAuthResult{
|
|
23
|
+
{
|
|
25
24
|
AccessToken: "token-123",
|
|
26
25
|
RefreshToken: "refresh-123",
|
|
27
26
|
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
|
|
28
27
|
AccountID: "acc-123",
|
|
29
28
|
Email: "oauth@example.com",
|
|
30
|
-
},
|
|
29
|
+
},
|
|
30
|
+
{
|
|
31
|
+
AccessToken: "token-456",
|
|
32
|
+
RefreshToken: "refresh-456",
|
|
33
|
+
ExpiresAt: time.Now().UTC().Add(2 * time.Hour),
|
|
34
|
+
AccountID: "acc-456",
|
|
35
|
+
Email: "second@example.com",
|
|
36
|
+
},
|
|
37
|
+
}
|
|
38
|
+
originalOAuthRunner := runOAuthLoginFlow
|
|
39
|
+
runOAuthLoginFlow = func(flow string) (auth.OAuthResult, error) {
|
|
40
|
+
result := results[0]
|
|
41
|
+
results = results[1:]
|
|
42
|
+
return result, nil
|
|
31
43
|
}
|
|
32
44
|
defer func() {
|
|
33
45
|
runOAuthLoginFlow = originalOAuthRunner
|
|
@@ -57,6 +69,47 @@ func TestAuthLoginAccountAndLogout(t *testing.T) {
|
|
|
57
69
|
t.Fatalf("expected stored account id")
|
|
58
70
|
}
|
|
59
71
|
|
|
72
|
+
authEmailFlag = "second@example.com"
|
|
73
|
+
if err := runAuthLogin(nil, nil); err != nil {
|
|
74
|
+
t.Fatalf("auth login second account: %v", err)
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
credentials, activeID, err := auth.List(repoRoot, "openai")
|
|
78
|
+
if err != nil {
|
|
79
|
+
t.Fatalf("list credentials: %v", err)
|
|
80
|
+
}
|
|
81
|
+
if len(credentials) != 2 {
|
|
82
|
+
t.Fatalf("expected 2 credentials, got %d", len(credentials))
|
|
83
|
+
}
|
|
84
|
+
if activeID != "acc-456" {
|
|
85
|
+
t.Fatalf("expected second account to become active, got %s", activeID)
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
if err := runAuthUse(nil, []string{"acc-123"}); err != nil {
|
|
89
|
+
t.Fatalf("auth use: %v", err)
|
|
90
|
+
}
|
|
91
|
+
active, err := auth.Get(repoRoot, "openai")
|
|
92
|
+
if err != nil {
|
|
93
|
+
t.Fatalf("get active credential: %v", err)
|
|
94
|
+
}
|
|
95
|
+
if active == nil || active.ID != "acc-123" {
|
|
96
|
+
t.Fatalf("expected acc-123 active, got %#v", active)
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
if err := runAuthRemove(nil, []string{"acc-456"}); err != nil {
|
|
100
|
+
t.Fatalf("auth remove: %v", err)
|
|
101
|
+
}
|
|
102
|
+
credentials, activeID, err = auth.List(repoRoot, "openai")
|
|
103
|
+
if err != nil {
|
|
104
|
+
t.Fatalf("list credentials after remove: %v", err)
|
|
105
|
+
}
|
|
106
|
+
if len(credentials) != 1 {
|
|
107
|
+
t.Fatalf("expected 1 credential after remove, got %d", len(credentials))
|
|
108
|
+
}
|
|
109
|
+
if activeID != "acc-123" {
|
|
110
|
+
t.Fatalf("expected acc-123 to remain active, got %s", activeID)
|
|
111
|
+
}
|
|
112
|
+
|
|
60
113
|
if err := runAuthLogout(nil, nil); err != nil {
|
|
61
114
|
t.Fatalf("auth logout: %v", err)
|
|
62
115
|
}
|
package/cmd/doctor.go
CHANGED
|
@@ -151,19 +151,27 @@ func errDetail(err error, fallback string) string {
|
|
|
151
151
|
|
|
152
152
|
func newDoctorOpenAIClient(cwd string, cfg config.OpenAIProviderConfig, authMode string, storedCred *auth.Credential) *openai.Client {
|
|
153
153
|
client := openai.New(cfg)
|
|
154
|
+
var accountSession *auth.AccountSession
|
|
155
|
+
if authMode == "account" && strings.TrimSpace(os.Getenv(cfg.AccountTokenEnv)) == "" {
|
|
156
|
+
accountSession = auth.NewAccountSession(cwd, "openai")
|
|
157
|
+
client.SetAccountFailoverHandler(func(ctx context.Context, err error) (string, bool, error) {
|
|
158
|
+
return accountSession.Failover(ctx, openai.AccountFailoverCooldown(err), err.Error())
|
|
159
|
+
})
|
|
160
|
+
client.SetAccountSuccessHandler(func(ctx context.Context) {
|
|
161
|
+
accountSession.MarkSuccess(ctx)
|
|
162
|
+
})
|
|
163
|
+
}
|
|
154
164
|
client.SetTokenResolver(func(ctx context.Context) (string, error) {
|
|
155
|
-
_ = ctx
|
|
156
165
|
if authMode == "api_key" {
|
|
157
166
|
if storedCred != nil && strings.TrimSpace(storedCred.Key) != "" {
|
|
158
167
|
return strings.TrimSpace(storedCred.Key), nil
|
|
159
168
|
}
|
|
160
169
|
return "", nil
|
|
161
170
|
}
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
return "", resolveErr
|
|
171
|
+
if accountSession == nil {
|
|
172
|
+
return "", nil
|
|
165
173
|
}
|
|
166
|
-
return
|
|
174
|
+
return accountSession.ResolveToken(ctx)
|
|
167
175
|
})
|
|
168
176
|
return client
|
|
169
177
|
}
|
package/cmd/interactive.go
CHANGED
|
@@ -961,8 +961,17 @@ func executeChatPrompt(prompt string) (*chatExecutionResult, error) {
|
|
|
961
961
|
}
|
|
962
962
|
|
|
963
963
|
client := openai.New(cfg.Provider.OpenAI)
|
|
964
|
+
var accountSession *auth.AccountSession
|
|
965
|
+
if strings.ToLower(strings.TrimSpace(cfg.Provider.OpenAI.AuthMode)) == "account" && strings.TrimSpace(os.Getenv(cfg.Provider.OpenAI.AccountTokenEnv)) == "" {
|
|
966
|
+
accountSession = auth.NewAccountSession(cwd, "openai")
|
|
967
|
+
client.SetAccountFailoverHandler(func(ctx context.Context, err error) (string, bool, error) {
|
|
968
|
+
return accountSession.Failover(ctx, openai.AccountFailoverCooldown(err), err.Error())
|
|
969
|
+
})
|
|
970
|
+
client.SetAccountSuccessHandler(func(ctx context.Context) {
|
|
971
|
+
accountSession.MarkSuccess(ctx)
|
|
972
|
+
})
|
|
973
|
+
}
|
|
964
974
|
client.SetTokenResolver(func(ctx context.Context) (string, error) {
|
|
965
|
-
_ = ctx
|
|
966
975
|
mode := strings.ToLower(strings.TrimSpace(cfg.Provider.OpenAI.AuthMode))
|
|
967
976
|
if mode == "api_key" {
|
|
968
977
|
cred, credErr := auth.Get(cwd, "openai")
|
|
@@ -974,7 +983,10 @@ func executeChatPrompt(prompt string) (*chatExecutionResult, error) {
|
|
|
974
983
|
}
|
|
975
984
|
return "", nil
|
|
976
985
|
}
|
|
977
|
-
|
|
986
|
+
if accountSession == nil {
|
|
987
|
+
return "", nil
|
|
988
|
+
}
|
|
989
|
+
return accountSession.ResolveToken(ctx)
|
|
978
990
|
})
|
|
979
991
|
|
|
980
992
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.Provider.OpenAI.TimeoutSeconds)*time.Second)
|
package/cmd/version.go
CHANGED
package/internal/auth/account.go
CHANGED
|
@@ -9,12 +9,34 @@ import (
|
|
|
9
9
|
const refreshSkew = 30 * time.Second
|
|
10
10
|
|
|
11
11
|
func ResolveAccountCredential(repoRoot, provider string) (*Credential, error) {
|
|
12
|
+
return resolveAccountCredentialByID(repoRoot, provider, "")
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
func resolveAccountCredentialByID(repoRoot, provider, credentialID string) (*Credential, error) {
|
|
12
16
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
|
13
17
|
if provider == "" {
|
|
14
18
|
return nil, fmt.Errorf("provider is required")
|
|
15
19
|
}
|
|
16
20
|
|
|
17
|
-
|
|
21
|
+
var (
|
|
22
|
+
cred *Credential
|
|
23
|
+
err error
|
|
24
|
+
)
|
|
25
|
+
if strings.TrimSpace(credentialID) == "" {
|
|
26
|
+
cred, err = Get(repoRoot, provider)
|
|
27
|
+
} else {
|
|
28
|
+
credentials, _, listErr := List(repoRoot, provider)
|
|
29
|
+
if listErr != nil {
|
|
30
|
+
return nil, listErr
|
|
31
|
+
}
|
|
32
|
+
for i := range credentials {
|
|
33
|
+
if credentials[i].ID == credentialID {
|
|
34
|
+
copy := credentials[i]
|
|
35
|
+
cred = ©
|
|
36
|
+
break
|
|
37
|
+
}
|
|
38
|
+
}
|
|
39
|
+
}
|
|
18
40
|
if err != nil {
|
|
19
41
|
return nil, err
|
|
20
42
|
}
|
|
@@ -56,7 +78,10 @@ func ResolveAccountCredential(repoRoot, provider string) (*Credential, error) {
|
|
|
56
78
|
return nil, err
|
|
57
79
|
}
|
|
58
80
|
|
|
59
|
-
|
|
81
|
+
if strings.TrimSpace(credentialID) == "" {
|
|
82
|
+
return Get(repoRoot, provider)
|
|
83
|
+
}
|
|
84
|
+
return resolveAccountCredentialByID(repoRoot, provider, credentialID)
|
|
60
85
|
}
|
|
61
86
|
|
|
62
87
|
func ResolveAccountAccessToken(repoRoot, provider string) (string, error) {
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
package auth
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"fmt"
|
|
6
|
+
"strings"
|
|
7
|
+
"time"
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
type AccountSession struct {
|
|
11
|
+
repoRoot string
|
|
12
|
+
provider string
|
|
13
|
+
currentID string
|
|
14
|
+
excluded map[string]struct{}
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
func NewAccountSession(repoRoot, provider string) *AccountSession {
|
|
18
|
+
return &AccountSession{
|
|
19
|
+
repoRoot: strings.TrimSpace(repoRoot),
|
|
20
|
+
provider: strings.ToLower(strings.TrimSpace(provider)),
|
|
21
|
+
excluded: map[string]struct{}{},
|
|
22
|
+
}
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
// ResolveToken returns the trimmed access token of the credential chosen by
// pickCredential. It records that credential's ID as the session's current
// one, which is what Failover and MarkSuccess act on. An error is returned
// when listing or selection fails, or when no usable OAuth credential exists.
// ctx is currently unused.
func (s *AccountSession) ResolveToken(ctx context.Context) (string, error) {
	_ = ctx
	picked, err := s.pickCredential()
	switch {
	case err != nil:
		return "", err
	case picked == nil:
		return "", fmt.Errorf("no active oauth credential available for provider %s", s.provider)
	}
	s.currentID = picked.ID
	return strings.TrimSpace(picked.AccessToken), nil
}
|
|
37
|
+
|
|
38
|
+
// Failover gives up on the session's current credential and selects another.
// It does the following:
//   - Persists the failure on the current credential: a cooldown when
//     cooldown > 0, the trimmed reason as LastError, and an updated timestamp.
//   - Excludes that credential until the next MarkSuccess.
//   - Asks pickCredential for an alternative.
//
// On a switch it returns (newToken, true, nil). It returns ("", false, nil)
// when there is no current credential or no alternative, and ("", false, err)
// when persisting or selection fails. ctx is currently unused.
func (s *AccountSession) Failover(ctx context.Context, cooldown time.Duration, reason string) (string, bool, error) {
	_ = ctx
	failedID := s.currentID
	if strings.TrimSpace(failedID) == "" {
		return "", false, nil
	}

	recordFailure := func(cred *Credential) error {
		if cooldown > 0 {
			cred.CooldownUntil = time.Now().UTC().Add(cooldown)
		}
		cred.LastError = strings.TrimSpace(reason)
		cred.UpdatedAt = time.Now().UTC()
		return nil
	}
	if err := mutateCredential(s.repoRoot, s.provider, failedID, recordFailure); err != nil {
		return "", false, err
	}
	s.excluded[failedID] = struct{}{}

	next, err := s.pickCredential()
	if err != nil || next == nil {
		// err is nil when there is simply no alternative left.
		return "", false, err
	}
	s.currentID = next.ID
	return strings.TrimSpace(next.AccessToken), true, nil
}
|
|
65
|
+
|
|
66
|
+
// MarkSuccess records that a request with the session's current credential
// succeeded. It clears the credential's LastError and cooldown, stamps
// LastUsedAt/UpdatedAt, and resets the session's exclusion set so every
// credential is eligible again. It does nothing when there is no current
// credential. ctx is currently unused.
func (s *AccountSession) MarkSuccess(ctx context.Context) {
	_ = ctx
	if strings.TrimSpace(s.currentID) != "" {
		// Best effort: this method has no error return, and failing to persist
		// bookkeeping must not turn an already-successful request into a failure.
		_ = mutateCredential(s.repoRoot, s.provider, s.currentID, func(cred *Credential) error {
			cred.LastError = ""
			cred.CooldownUntil = time.Time{}
			cred.LastUsedAt = time.Now().UTC()
			cred.UpdatedAt = time.Now().UTC()
			return nil
		})
		s.excluded = map[string]struct{}{}
	}
}
|
|
80
|
+
|
|
81
|
+
// pickCredential lists the provider's stored credentials and walks them in
// preference order: the session's current credential, then the persisted
// active one, then the rest (see orderCredentials). It returns the first
// candidate that:
//   - has Type "oauth",
//   - is not excluded by this session,
//   - is not cooling down,
//   - resolves via resolveAccountCredentialByID (presumably refreshing expired
//     tokens; confirm there).
//
// A candidate that is not the persisted active credential is made active
// before it is resolved. A candidate that fails to resolve gets a 5-minute
// cooldown and is excluded for the rest of this session. It returns
// (nil, nil) when nothing is usable, and an error only when listing or
// SetActive fails.
//
// NOTE(review): because SetActive runs before resolution, the store keeps
// pointing at the last candidate tried when every candidate fails.
func (s *AccountSession) pickCredential() (*Credential, error) {
	credentials, activeID, err := List(s.repoRoot, s.provider)
	if err != nil {
		return nil, err
	}
	if len(credentials) == 0 {
		return nil, nil
	}

	now := time.Now().UTC()
	ordered := orderCredentials(credentials, activeID, s.currentID)
	for _, candidate := range ordered {
		if candidate.Type != "oauth" {
			continue
		}
		if _, skip := s.excluded[candidate.ID]; skip {
			continue
		}
		if !candidate.CooldownUntil.IsZero() && candidate.CooldownUntil.After(now) {
			continue
		}
		if candidate.ID != activeID {
			if err := SetActive(s.repoRoot, s.provider, candidate.ID); err != nil {
				return nil, err
			}
			// Keep the local view of the persisted active ID in sync. Otherwise,
			// if this candidate fails to resolve and a later candidate equals the
			// ID List originally reported as active, SetActive would be skipped
			// and the store would stay pointed at this failed candidate while we
			// return a different one.
			activeID = candidate.ID
		}
		resolved, err := resolveAccountCredentialByID(s.repoRoot, s.provider, candidate.ID)
		if err == nil {
			return resolved, nil
		}
		// Best effort: failing to persist the cooldown must not stop us from
		// trying the remaining candidates; the in-memory exclusion below still
		// keeps this session from retrying it.
		_ = mutateCredential(s.repoRoot, s.provider, candidate.ID, func(cred *Credential) error {
			cred.LastError = strings.TrimSpace(err.Error())
			cred.CooldownUntil = now.Add(5 * time.Minute)
			cred.UpdatedAt = now
			return nil
		})
		s.excluded[candidate.ID] = struct{}{}
	}

	return nil, nil
}
|
|
122
|
+
|
|
123
|
+
// orderCredentials returns credentials in failover preference order. The
// credential whose ID is currentID comes first, then the one whose ID is
// activeID, then every remaining credential in its original order. Blank
// currentID/activeID values and IDs with no matching credential are ignored.
// Each ID appears at most once (its first occurrence wins). The input slice
// is not modified.
func orderCredentials(credentials []Credential, activeID, currentID string) []Credential {
	ordered := make([]Credential, 0, len(credentials))
	seen := make(map[string]struct{}, len(credentials))
	take := func(cred Credential) {
		ordered = append(ordered, cred)
		seen[cred.ID] = struct{}{}
	}

	for _, preferred := range []string{currentID, activeID} {
		if strings.TrimSpace(preferred) == "" {
			continue
		}
		if _, already := seen[preferred]; already {
			continue
		}
		for _, cred := range credentials {
			if cred.ID == preferred {
				take(cred)
				break
			}
		}
	}

	for _, cred := range credentials {
		if _, already := seen[cred.ID]; !already {
			take(cred)
		}
	}
	return ordered
}
|
|
146
|
+
|
|
147
|
+
func containsCredential(credentials []Credential, credentialID string) bool {
|
|
148
|
+
for _, cred := range credentials {
|
|
149
|
+
if cred.ID == credentialID {
|
|
150
|
+
return true
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
return false
|
|
154
|
+
}
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
package auth
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"encoding/base64"
|
|
6
|
+
"fmt"
|
|
7
|
+
"testing"
|
|
8
|
+
"time"
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
// TestAccountSessionFailsOverToNextCredential stores two OAuth accounts with
// acc-1 active, resolves a token, then fails over. It checks that:
//   - the session switched to acc-2 and the store now marks acc-2 active,
//   - acc-1 carries a future cooldown and the failure reason,
//   - later token resolutions stay on acc-2.
func TestAccountSessionFailsOverToNextCredential(t *testing.T) {
	repoRoot := t.TempDir()
	if err := Set(repoRoot, "openai", Credential{Type: "oauth", AccessToken: testSessionAccountToken("acc-1"), RefreshToken: "refresh-1", AccountID: "acc-1", Email: "one@example.com"}); err != nil {
		t.Fatalf("set first account: %v", err)
	}
	if err := Set(repoRoot, "openai", Credential{Type: "oauth", AccessToken: testSessionAccountToken("acc-2"), RefreshToken: "refresh-2", AccountID: "acc-2", Email: "two@example.com"}); err != nil {
		t.Fatalf("set second account: %v", err)
	}
	if err := SetActive(repoRoot, "openai", "acc-1"); err != nil {
		t.Fatalf("set active: %v", err)
	}

	session := NewAccountSession(repoRoot, "openai")
	token, err := session.ResolveToken(context.Background())
	if err != nil {
		t.Fatalf("resolve token: %v", err)
	}
	if token != testSessionAccountToken("acc-1") {
		t.Fatalf("expected first token, got %q", token)
	}

	nextToken, switched, err := session.Failover(context.Background(), time.Minute, "rate limited")
	if err != nil {
		t.Fatalf("failover: %v", err)
	}
	if !switched {
		t.Fatalf("expected failover switch")
	}
	if nextToken != testSessionAccountToken("acc-2") {
		t.Fatalf("expected second token, got %q", nextToken)
	}

	active, err := Get(repoRoot, "openai")
	if err != nil {
		t.Fatalf("get active credential: %v", err)
	}
	if active == nil || active.ID != "acc-2" {
		t.Fatalf("expected acc-2 active after failover, got %#v", active)
	}

	credentials, _, err := List(repoRoot, "openai")
	if err != nil {
		t.Fatalf("list credentials: %v", err)
	}
	var first *Credential
	for i := range credentials {
		if credentials[i].ID == "acc-1" {
			first = &credentials[i]
			break
		}
	}
	if first == nil {
		t.Fatalf("expected acc-1 to remain stored after failover")
	}
	if !first.CooldownUntil.After(time.Now().UTC()) {
		t.Fatalf("expected acc-1 cooldown in the future, got %v", first.CooldownUntil)
	}
	if first.LastError != "rate limited" {
		t.Fatalf("expected acc-1 last error %q, got %q", "rate limited", first.LastError)
	}

	// Subsequent requests in the same session must stick to the new account.
	again, err := session.ResolveToken(context.Background())
	if err != nil {
		t.Fatalf("resolve token after failover: %v", err)
	}
	if again != testSessionAccountToken("acc-2") {
		t.Fatalf("expected session to stay on acc-2, got %q", again)
	}
}
|
|
65
|
+
|
|
66
|
+
// testSessionAccountToken builds an unsigned, JWT-shaped token of the form
// header.payload.sig. Both parts are base64url-encoded without padding, and
// the payload carries accountID as the chatgpt_account_id claim under
// "https://api.openai.com/auth". The trailing "sig" is a placeholder, not a
// real signature.
func testSessionAccountToken(accountID string) string {
	enc := base64.RawURLEncoding
	claims := fmt.Sprintf(`{"https://api.openai.com/auth":{"chatgpt_account_id":"%s"}}`, accountID)
	encodedHeader := enc.EncodeToString([]byte(`{"alg":"none"}`))
	encodedClaims := enc.EncodeToString([]byte(claims))
	return encodedHeader + "." + encodedClaims + ".sig"
}
|