orch-code 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +12 -0
- package/LICENSE +21 -0
- package/README.md +624 -0
- package/cmd/apply.go +111 -0
- package/cmd/auth.go +393 -0
- package/cmd/auth_test.go +100 -0
- package/cmd/diff.go +57 -0
- package/cmd/doctor.go +149 -0
- package/cmd/explain.go +192 -0
- package/cmd/explain_test.go +62 -0
- package/cmd/init.go +100 -0
- package/cmd/interactive.go +1372 -0
- package/cmd/interactive_input.go +45 -0
- package/cmd/interactive_input_test.go +55 -0
- package/cmd/logs.go +72 -0
- package/cmd/model.go +84 -0
- package/cmd/plan.go +149 -0
- package/cmd/provider.go +189 -0
- package/cmd/provider_model_doctor_test.go +91 -0
- package/cmd/root.go +67 -0
- package/cmd/run.go +123 -0
- package/cmd/run_engine.go +208 -0
- package/cmd/run_engine_test.go +30 -0
- package/cmd/session.go +589 -0
- package/cmd/session_helpers.go +54 -0
- package/cmd/session_integration_test.go +30 -0
- package/cmd/session_list_current_test.go +87 -0
- package/cmd/session_messages_test.go +163 -0
- package/cmd/session_runs_test.go +68 -0
- package/cmd/sprint1_integration_test.go +119 -0
- package/cmd/stats.go +173 -0
- package/cmd/stats_test.go +71 -0
- package/cmd/version.go +4 -0
- package/go.mod +45 -0
- package/go.sum +108 -0
- package/internal/agents/agent.go +31 -0
- package/internal/agents/coder.go +167 -0
- package/internal/agents/planner.go +155 -0
- package/internal/agents/reviewer.go +118 -0
- package/internal/agents/runtime.go +25 -0
- package/internal/agents/runtime_test.go +77 -0
- package/internal/auth/account.go +78 -0
- package/internal/auth/oauth.go +523 -0
- package/internal/auth/store.go +287 -0
- package/internal/confidence/policy.go +174 -0
- package/internal/confidence/policy_test.go +71 -0
- package/internal/confidence/scorer.go +253 -0
- package/internal/confidence/scorer_test.go +83 -0
- package/internal/config/config.go +331 -0
- package/internal/config/config_defaults_test.go +138 -0
- package/internal/execution/contract_builder.go +160 -0
- package/internal/execution/contract_builder_test.go +68 -0
- package/internal/execution/plan_compliance.go +161 -0
- package/internal/execution/plan_compliance_test.go +71 -0
- package/internal/execution/retry_directive.go +132 -0
- package/internal/execution/scope_guard.go +69 -0
- package/internal/logger/logger.go +120 -0
- package/internal/models/contracts_test.go +100 -0
- package/internal/models/models.go +269 -0
- package/internal/orchestrator/orchestrator.go +701 -0
- package/internal/orchestrator/orchestrator_retry_test.go +135 -0
- package/internal/orchestrator/review_engine_test.go +50 -0
- package/internal/orchestrator/state.go +42 -0
- package/internal/orchestrator/test_classifier_test.go +68 -0
- package/internal/patch/applier.go +131 -0
- package/internal/patch/applier_test.go +25 -0
- package/internal/patch/parser.go +89 -0
- package/internal/patch/patch.go +60 -0
- package/internal/patch/summary.go +30 -0
- package/internal/patch/validator.go +104 -0
- package/internal/planning/normalizer.go +416 -0
- package/internal/planning/normalizer_test.go +64 -0
- package/internal/providers/errors.go +35 -0
- package/internal/providers/openai/client.go +498 -0
- package/internal/providers/openai/client_test.go +187 -0
- package/internal/providers/provider.go +47 -0
- package/internal/providers/registry.go +32 -0
- package/internal/providers/registry_test.go +57 -0
- package/internal/providers/router.go +52 -0
- package/internal/providers/state.go +114 -0
- package/internal/providers/state_test.go +64 -0
- package/internal/repo/analyzer.go +188 -0
- package/internal/repo/context.go +83 -0
- package/internal/review/engine.go +267 -0
- package/internal/review/engine_test.go +103 -0
- package/internal/runstore/store.go +137 -0
- package/internal/runstore/store_test.go +59 -0
- package/internal/runtime/lock.go +150 -0
- package/internal/runtime/lock_test.go +57 -0
- package/internal/session/compaction.go +260 -0
- package/internal/session/compaction_test.go +36 -0
- package/internal/session/service.go +117 -0
- package/internal/session/service_test.go +113 -0
- package/internal/storage/storage.go +1498 -0
- package/internal/storage/storage_test.go +413 -0
- package/internal/testing/classifier.go +80 -0
- package/internal/testing/classifier_test.go +36 -0
- package/internal/tools/command.go +160 -0
- package/internal/tools/command_test.go +56 -0
- package/internal/tools/file.go +111 -0
- package/internal/tools/git.go +77 -0
- package/internal/tools/invalid_params_test.go +36 -0
- package/internal/tools/policy.go +98 -0
- package/internal/tools/policy_test.go +36 -0
- package/internal/tools/registry_test.go +52 -0
- package/internal/tools/result.go +30 -0
- package/internal/tools/search.go +86 -0
- package/internal/tools/tool.go +94 -0
- package/main.go +9 -0
- package/npm/orch.js +25 -0
- package/package.json +41 -0
- package/scripts/changelog.js +20 -0
- package/scripts/check-release-version.js +21 -0
- package/scripts/lib/release-utils.js +223 -0
- package/scripts/postinstall.js +157 -0
- package/scripts/release.js +52 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
package agents
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"testing"
|
|
6
|
+
"time"
|
|
7
|
+
|
|
8
|
+
"github.com/furkanbeydemir/orch/internal/config"
|
|
9
|
+
"github.com/furkanbeydemir/orch/internal/models"
|
|
10
|
+
"github.com/furkanbeydemir/orch/internal/providers"
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
type providerStub struct {
|
|
14
|
+
name string
|
|
15
|
+
text string
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
func (p providerStub) Name() string { return p.name }
|
|
19
|
+
|
|
20
|
+
func (p providerStub) Validate(ctx context.Context) error { return nil }
|
|
21
|
+
|
|
22
|
+
func (p providerStub) Chat(ctx context.Context, req providers.ChatRequest) (providers.ChatResponse, error) {
|
|
23
|
+
return providers.ChatResponse{Text: p.text, FinishReason: "completed"}, nil
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
func (p providerStub) Stream(ctx context.Context, req providers.ChatRequest) (<-chan providers.StreamEvent, <-chan error) {
|
|
27
|
+
ev := make(chan providers.StreamEvent)
|
|
28
|
+
err := make(chan error)
|
|
29
|
+
close(ev)
|
|
30
|
+
close(err)
|
|
31
|
+
return ev, err
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
func TestPlannerUsesProviderRuntime(t *testing.T) {
|
|
35
|
+
cfg := config.DefaultConfig()
|
|
36
|
+
reg := providers.NewRegistry()
|
|
37
|
+
reg.Register(providerStub{name: "openai", text: "Plan from provider"})
|
|
38
|
+
router := providers.NewRouter(cfg, reg)
|
|
39
|
+
|
|
40
|
+
planner := NewPlanner("gpt-5.3-codex")
|
|
41
|
+
planner.SetRuntime(&LLMRuntime{Router: router})
|
|
42
|
+
|
|
43
|
+
output, err := planner.Execute(&Input{Task: &models.Task{ID: "t1", Description: "demo", CreatedAt: time.Now()}})
|
|
44
|
+
if err != nil {
|
|
45
|
+
t.Fatalf("planner execute: %v", err)
|
|
46
|
+
}
|
|
47
|
+
if output == nil || output.Plan == nil || len(output.Plan.Steps) == 0 {
|
|
48
|
+
t.Fatalf("expected plan output")
|
|
49
|
+
}
|
|
50
|
+
if output.Plan.Steps[0].Description != "Plan from provider" {
|
|
51
|
+
t.Fatalf("unexpected planner description: %q", output.Plan.Steps[0].Description)
|
|
52
|
+
}
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
func TestReviewerParsesReviseDecision(t *testing.T) {
|
|
56
|
+
cfg := config.DefaultConfig()
|
|
57
|
+
reg := providers.NewRegistry()
|
|
58
|
+
reg.Register(providerStub{name: "openai", text: "revise: missing tests"})
|
|
59
|
+
router := providers.NewRouter(cfg, reg)
|
|
60
|
+
|
|
61
|
+
reviewer := NewReviewer("gpt-5.3-codex")
|
|
62
|
+
reviewer.SetRuntime(&LLMRuntime{Router: router})
|
|
63
|
+
|
|
64
|
+
output, err := reviewer.Execute(&Input{
|
|
65
|
+
Task: &models.Task{ID: "t1", Description: "demo", CreatedAt: time.Now()},
|
|
66
|
+
Patch: &models.Patch{TaskID: "t1", RawDiff: ""},
|
|
67
|
+
})
|
|
68
|
+
if err != nil {
|
|
69
|
+
t.Fatalf("reviewer execute: %v", err)
|
|
70
|
+
}
|
|
71
|
+
if output == nil || output.Review == nil {
|
|
72
|
+
t.Fatalf("expected review output")
|
|
73
|
+
}
|
|
74
|
+
if output.Review.Decision != models.ReviewRevise {
|
|
75
|
+
t.Fatalf("expected revise decision")
|
|
76
|
+
}
|
|
77
|
+
}
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
package auth
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"fmt"
|
|
5
|
+
"strings"
|
|
6
|
+
"time"
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
// refreshSkew is how long before its recorded expiry a credential is already
// treated as due for refresh.
const refreshSkew time.Duration = 30 * time.Second
|
|
10
|
+
|
|
11
|
+
func ResolveAccountCredential(repoRoot, provider string) (*Credential, error) {
|
|
12
|
+
provider = strings.ToLower(strings.TrimSpace(provider))
|
|
13
|
+
if provider == "" {
|
|
14
|
+
return nil, fmt.Errorf("provider is required")
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
cred, err := Get(repoRoot, provider)
|
|
18
|
+
if err != nil {
|
|
19
|
+
return nil, err
|
|
20
|
+
}
|
|
21
|
+
if cred == nil {
|
|
22
|
+
return nil, fmt.Errorf("no stored credential for provider %s", provider)
|
|
23
|
+
}
|
|
24
|
+
if strings.ToLower(strings.TrimSpace(cred.Type)) != "oauth" {
|
|
25
|
+
return nil, fmt.Errorf("stored credential for %s is not oauth", provider)
|
|
26
|
+
}
|
|
27
|
+
if strings.TrimSpace(cred.AccessToken) == "" {
|
|
28
|
+
return nil, fmt.Errorf("stored oauth access token is empty for %s", provider)
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
if !shouldRefresh(cred) {
|
|
32
|
+
return cred, nil
|
|
33
|
+
}
|
|
34
|
+
if strings.TrimSpace(cred.RefreshToken) == "" {
|
|
35
|
+
return nil, fmt.Errorf("oauth access token expired and no refresh token is available")
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
refreshed, err := RefreshOAuthToken(cred.RefreshToken)
|
|
39
|
+
if err != nil {
|
|
40
|
+
return nil, fmt.Errorf("failed to refresh oauth token: %w", err)
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
cred.AccessToken = strings.TrimSpace(refreshed.AccessToken)
|
|
44
|
+
if strings.TrimSpace(refreshed.RefreshToken) != "" {
|
|
45
|
+
cred.RefreshToken = strings.TrimSpace(refreshed.RefreshToken)
|
|
46
|
+
}
|
|
47
|
+
cred.ExpiresAt = refreshed.ExpiresAt
|
|
48
|
+
if strings.TrimSpace(refreshed.AccountID) != "" {
|
|
49
|
+
cred.AccountID = strings.TrimSpace(refreshed.AccountID)
|
|
50
|
+
}
|
|
51
|
+
if strings.TrimSpace(refreshed.Email) != "" {
|
|
52
|
+
cred.Email = strings.TrimSpace(refreshed.Email)
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
if err := Set(repoRoot, provider, *cred); err != nil {
|
|
56
|
+
return nil, err
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
return Get(repoRoot, provider)
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
func ResolveAccountAccessToken(repoRoot, provider string) (string, error) {
|
|
63
|
+
cred, err := ResolveAccountCredential(repoRoot, provider)
|
|
64
|
+
if err != nil {
|
|
65
|
+
return "", err
|
|
66
|
+
}
|
|
67
|
+
return strings.TrimSpace(cred.AccessToken), nil
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
func shouldRefresh(cred *Credential) bool {
|
|
71
|
+
if cred == nil {
|
|
72
|
+
return false
|
|
73
|
+
}
|
|
74
|
+
if cred.ExpiresAt.IsZero() {
|
|
75
|
+
return false
|
|
76
|
+
}
|
|
77
|
+
return time.Now().UTC().After(cred.ExpiresAt.Add(-refreshSkew))
|
|
78
|
+
}
|
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
package auth
|
|
2
|
+
|
|
3
|
+
import (
|
|
4
|
+
"context"
|
|
5
|
+
"crypto/rand"
|
|
6
|
+
"crypto/sha256"
|
|
7
|
+
"encoding/base64"
|
|
8
|
+
"encoding/json"
|
|
9
|
+
"fmt"
|
|
10
|
+
"io"
|
|
11
|
+
"net/http"
|
|
12
|
+
"net/url"
|
|
13
|
+
"os/exec"
|
|
14
|
+
"runtime"
|
|
15
|
+
"strconv"
|
|
16
|
+
"strings"
|
|
17
|
+
"time"
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
// OAuth client identity and remote endpoints.
const (
	// clientID is sent as client_id on authorization, token and device
	// requests.
	clientID = "app_EMoamEEZ73f0CkXaXp7hrann"
	// authURL is the authorization endpoint the browser flow opens.
	authURL = "https://auth.openai.com/oauth/authorize"
	// tokenURL is the endpoint used for code exchange and token refresh.
	tokenURL = "https://auth.openai.com/oauth/token"
	// deviceUserCodeURL starts a headless (device) login.
	deviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
	// deviceTokenURL is polled until the headless login is authorized.
	deviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
	// deviceRedirectURI is the redirect_uri sent when exchanging a headless
	// authorization code.
	deviceRedirectURI = "https://auth.openai.com/deviceauth/callback"
)

// Local callback server settings for the browser flow.
const (
	// redirectURI is the redirect_uri registered for the browser flow.
	redirectURI = "http://localhost:1455/auth/callback"
	// callbackHost is the address the local callback server listens on.
	callbackHost = "localhost:1455"
)

// Timing defaults.
const (
	// defaultTokenExpirySeconds is used when a token response carries no
	// positive expires_in.
	defaultTokenExpirySeconds = 3600
	// oauthCallbackWaitTimeout bounds the wait for the browser callback.
	oauthCallbackWaitTimeout = 5 * time.Minute
	// headlessPollingSafetyMargin is added to the polling interval of the
	// headless flow.
	headlessPollingSafetyMargin = 3 * time.Second
)
|
|
33
|
+
|
|
34
|
+
// OAuthResult is the outcome of a successful login or token refresh.
type OAuthResult struct {
	// AccessToken and RefreshToken are the tokens from the token response,
	// with surrounding whitespace removed.
	AccessToken, RefreshToken string
	// ExpiresAt is the UTC time at which the access token is considered
	// expired.
	ExpiresAt time.Time
	// AccountID and Email are taken from JWT claims when present; either may
	// be empty.
	AccountID, Email string
}
|
|
41
|
+
|
|
42
|
+
// RunOAuthFlow executes OpenAI OAuth login.
|
|
43
|
+
// Supported flows: auto, browser, headless.
|
|
44
|
+
func RunOAuthFlow(flow string) (OAuthResult, error) {
|
|
45
|
+
normalized := strings.ToLower(strings.TrimSpace(flow))
|
|
46
|
+
if normalized == "" {
|
|
47
|
+
normalized = "auto"
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
switch normalized {
|
|
51
|
+
case "browser":
|
|
52
|
+
return runBrowserOAuthFlow()
|
|
53
|
+
case "headless":
|
|
54
|
+
return runHeadlessOAuthFlow()
|
|
55
|
+
case "auto":
|
|
56
|
+
browser, browserErr := runBrowserOAuthFlow()
|
|
57
|
+
if browserErr == nil {
|
|
58
|
+
return browser, nil
|
|
59
|
+
}
|
|
60
|
+
fmt.Printf("\nBrowser login failed: %v\n", browserErr)
|
|
61
|
+
fmt.Println("Falling back to headless device login...")
|
|
62
|
+
headless, headlessErr := runHeadlessOAuthFlow()
|
|
63
|
+
if headlessErr != nil {
|
|
64
|
+
return OAuthResult{}, fmt.Errorf("browser flow failed: %v; headless flow failed: %w", browserErr, headlessErr)
|
|
65
|
+
}
|
|
66
|
+
return headless, nil
|
|
67
|
+
default:
|
|
68
|
+
return OAuthResult{}, fmt.Errorf("unsupported oauth flow: %s", flow)
|
|
69
|
+
}
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
func RefreshOAuthToken(refreshToken string) (OAuthResult, error) {
|
|
73
|
+
refreshToken = strings.TrimSpace(refreshToken)
|
|
74
|
+
if refreshToken == "" {
|
|
75
|
+
return OAuthResult{}, fmt.Errorf("refresh token is required")
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
tokens, err := requestToken(url.Values{
|
|
79
|
+
"grant_type": {"refresh_token"},
|
|
80
|
+
"refresh_token": {refreshToken},
|
|
81
|
+
"client_id": {clientID},
|
|
82
|
+
})
|
|
83
|
+
if err != nil {
|
|
84
|
+
return OAuthResult{}, err
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
return parseOAuthResult(tokens)
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
func runBrowserOAuthFlow() (OAuthResult, error) {
|
|
91
|
+
// Generate state and PKCE challenge
|
|
92
|
+
state := generateRandomString(32)
|
|
93
|
+
verifier := generateRandomString(64)
|
|
94
|
+
challenge := generateCodeChallenge(verifier)
|
|
95
|
+
|
|
96
|
+
// Build auth URL.
|
|
97
|
+
params := url.Values{}
|
|
98
|
+
params.Add("response_type", "code")
|
|
99
|
+
params.Add("client_id", clientID)
|
|
100
|
+
params.Add("redirect_uri", redirectURI)
|
|
101
|
+
params.Add("scope", "openid profile email offline_access")
|
|
102
|
+
params.Add("state", state)
|
|
103
|
+
params.Add("code_challenge", challenge)
|
|
104
|
+
params.Add("code_challenge_method", "S256")
|
|
105
|
+
params.Add("id_token_add_organizations", "true")
|
|
106
|
+
params.Add("codex_cli_simplified_flow", "true")
|
|
107
|
+
|
|
108
|
+
// Identify CLI origin.
|
|
109
|
+
params.Add("originator", "orch")
|
|
110
|
+
|
|
111
|
+
loginURL := fmt.Sprintf("%s?%s", authURL, params.Encode())
|
|
112
|
+
|
|
113
|
+
// Start local server to receive callback.
|
|
114
|
+
addr := callbackHost
|
|
115
|
+
mux := http.NewServeMux()
|
|
116
|
+
srv := &http.Server{
|
|
117
|
+
Addr: addr,
|
|
118
|
+
Handler: mux,
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
codeChan := make(chan string, 1)
|
|
122
|
+
errChan := make(chan error, 1)
|
|
123
|
+
|
|
124
|
+
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
|
125
|
+
query := r.URL.Query()
|
|
126
|
+
|
|
127
|
+
if errDesc := query.Get("error_description"); errDesc != "" {
|
|
128
|
+
select {
|
|
129
|
+
case errChan <- fmt.Errorf("auth error: %s", errDesc):
|
|
130
|
+
default:
|
|
131
|
+
}
|
|
132
|
+
fmt.Fprintf(w, "Auth error: %s. You can close this window.", errDesc)
|
|
133
|
+
return
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
if returnedState := query.Get("state"); returnedState != state {
|
|
137
|
+
select {
|
|
138
|
+
case errChan <- fmt.Errorf("state mismatch: expected %s, got %s", state, returnedState):
|
|
139
|
+
default:
|
|
140
|
+
}
|
|
141
|
+
fmt.Fprintln(w, "State mismatch error. You can close this window.")
|
|
142
|
+
return
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
code := query.Get("code")
|
|
146
|
+
if code == "" {
|
|
147
|
+
select {
|
|
148
|
+
case errChan <- fmt.Errorf("no code returned"):
|
|
149
|
+
default:
|
|
150
|
+
}
|
|
151
|
+
fmt.Fprintln(w, "No code provided. You can close this window.")
|
|
152
|
+
return
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
// Success.
|
|
156
|
+
fmt.Fprintln(w, "Login successful! You can close this window and return to Orch.")
|
|
157
|
+
select {
|
|
158
|
+
case codeChan <- code:
|
|
159
|
+
default:
|
|
160
|
+
}
|
|
161
|
+
})
|
|
162
|
+
|
|
163
|
+
go func() {
|
|
164
|
+
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
165
|
+
select {
|
|
166
|
+
case errChan <- fmt.Errorf("could not start local server on %s: %w", addr, err):
|
|
167
|
+
default:
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
}()
|
|
171
|
+
|
|
172
|
+
fmt.Println("\nLogin to ChatGPT Plus/Pro/Codex Subscription")
|
|
173
|
+
fmt.Printf("\n%s\n\n", loginURL)
|
|
174
|
+
if err := openBrowser(loginURL); err != nil {
|
|
175
|
+
fmt.Printf("Could not open browser automatically: %v\n", err)
|
|
176
|
+
}
|
|
177
|
+
fmt.Println("Ctrl+click to open if needed")
|
|
178
|
+
fmt.Println("\nA browser window should open. Complete login to finish.")
|
|
179
|
+
fmt.Println("\nWaiting for browser callback... (Press Ctrl+C to cancel)")
|
|
180
|
+
|
|
181
|
+
var code string
|
|
182
|
+
select {
|
|
183
|
+
case c := <-codeChan:
|
|
184
|
+
code = c
|
|
185
|
+
case err := <-errChan:
|
|
186
|
+
_ = srv.Shutdown(context.Background())
|
|
187
|
+
return OAuthResult{}, err
|
|
188
|
+
case <-time.After(oauthCallbackWaitTimeout):
|
|
189
|
+
_ = srv.Shutdown(context.Background())
|
|
190
|
+
return OAuthResult{}, fmt.Errorf("oauth callback timed out after %s", oauthCallbackWaitTimeout)
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
// Exchange code for token.
|
|
194
|
+
tokens, err := exchangeCodeForToken(code, verifier, redirectURI)
|
|
195
|
+
|
|
196
|
+
// Shut down server.
|
|
197
|
+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
198
|
+
defer cancel()
|
|
199
|
+
_ = srv.Shutdown(ctx)
|
|
200
|
+
|
|
201
|
+
if err != nil {
|
|
202
|
+
return OAuthResult{}, fmt.Errorf("failed to exchange code for token: %w", err)
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
return parseOAuthResult(tokens)
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
func runHeadlessOAuthFlow() (OAuthResult, error) {
|
|
209
|
+
deviceAuth, err := requestDeviceAuth()
|
|
210
|
+
if err != nil {
|
|
211
|
+
return OAuthResult{}, err
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
fmt.Println("\nHeadless login to ChatGPT Plus/Pro/Codex Subscription")
|
|
215
|
+
fmt.Println("Open: https://auth.openai.com/codex/device")
|
|
216
|
+
fmt.Printf("Enter code: %s\n", deviceAuth.UserCode)
|
|
217
|
+
fmt.Println("Waiting for authorization... (Press Ctrl+C to cancel)")
|
|
218
|
+
|
|
219
|
+
code, verifier, err := pollForDeviceAuthorization(deviceAuth)
|
|
220
|
+
if err != nil {
|
|
221
|
+
return OAuthResult{}, err
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
tokens, err := exchangeCodeForToken(code, verifier, deviceRedirectURI)
|
|
225
|
+
if err != nil {
|
|
226
|
+
return OAuthResult{}, fmt.Errorf("failed to exchange headless authorization code: %w", err)
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
return parseOAuthResult(tokens)
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
func exchangeCodeForToken(code, verifier, redirect string) (*tokenResponse, error) {
|
|
233
|
+
data := url.Values{}
|
|
234
|
+
data.Set("grant_type", "authorization_code")
|
|
235
|
+
data.Set("client_id", clientID)
|
|
236
|
+
data.Set("code", code)
|
|
237
|
+
data.Set("redirect_uri", redirect)
|
|
238
|
+
data.Set("code_verifier", verifier)
|
|
239
|
+
return requestToken(data)
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
func requestToken(data url.Values) (*tokenResponse, error) {
|
|
243
|
+
|
|
244
|
+
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(data.Encode()))
|
|
245
|
+
if err != nil {
|
|
246
|
+
return nil, err
|
|
247
|
+
}
|
|
248
|
+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
249
|
+
req.Header.Set("Accept", "application/json")
|
|
250
|
+
|
|
251
|
+
client := &http.Client{Timeout: 15 * time.Second}
|
|
252
|
+
resp, err := client.Do(req)
|
|
253
|
+
if err != nil {
|
|
254
|
+
return nil, err
|
|
255
|
+
}
|
|
256
|
+
defer resp.Body.Close()
|
|
257
|
+
|
|
258
|
+
body, err := io.ReadAll(resp.Body)
|
|
259
|
+
if err != nil {
|
|
260
|
+
return nil, err
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
264
|
+
return nil, fmt.Errorf("token request failed (status %d): %s", resp.StatusCode, string(body))
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
var result tokenResponse
|
|
268
|
+
if err := json.Unmarshal(body, &result); err != nil {
|
|
269
|
+
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
if result.AccessToken == "" {
|
|
273
|
+
return nil, fmt.Errorf("no access token in response: %s", string(body))
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
return &result, nil
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
type tokenResponse struct {
|
|
280
|
+
AccessToken string `json:"access_token"`
|
|
281
|
+
RefreshToken string `json:"refresh_token"`
|
|
282
|
+
IDToken string `json:"id_token"`
|
|
283
|
+
ExpiresIn int `json:"expires_in"`
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
func parseOAuthResult(tokens *tokenResponse) (OAuthResult, error) {
|
|
287
|
+
if tokens == nil {
|
|
288
|
+
return OAuthResult{}, fmt.Errorf("empty oauth token response")
|
|
289
|
+
}
|
|
290
|
+
if strings.TrimSpace(tokens.AccessToken) == "" {
|
|
291
|
+
return OAuthResult{}, fmt.Errorf("oauth access token is empty")
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
expiresIn := tokens.ExpiresIn
|
|
295
|
+
if expiresIn <= 0 {
|
|
296
|
+
expiresIn = defaultTokenExpirySeconds
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
result := OAuthResult{
|
|
300
|
+
AccessToken: strings.TrimSpace(tokens.AccessToken),
|
|
301
|
+
RefreshToken: strings.TrimSpace(tokens.RefreshToken),
|
|
302
|
+
ExpiresAt: time.Now().UTC().Add(time.Duration(expiresIn) * time.Second),
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
if idClaims := decodeJWTClaims(tokens.IDToken); idClaims != nil {
|
|
306
|
+
result.AccountID = extractAccountID(idClaims)
|
|
307
|
+
result.Email = extractEmail(idClaims)
|
|
308
|
+
}
|
|
309
|
+
if result.AccountID == "" || result.Email == "" {
|
|
310
|
+
if accessClaims := decodeJWTClaims(tokens.AccessToken); accessClaims != nil {
|
|
311
|
+
if result.AccountID == "" {
|
|
312
|
+
result.AccountID = extractAccountID(accessClaims)
|
|
313
|
+
}
|
|
314
|
+
if result.Email == "" {
|
|
315
|
+
result.Email = extractEmail(accessClaims)
|
|
316
|
+
}
|
|
317
|
+
}
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
return result, nil
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
type deviceAuthResponse struct {
|
|
324
|
+
DeviceAuthID string
|
|
325
|
+
UserCode string
|
|
326
|
+
Interval time.Duration
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
func requestDeviceAuth() (*deviceAuthResponse, error) {
|
|
330
|
+
body, _ := json.Marshal(map[string]string{"client_id": clientID})
|
|
331
|
+
req, err := http.NewRequest(http.MethodPost, deviceUserCodeURL, strings.NewReader(string(body)))
|
|
332
|
+
if err != nil {
|
|
333
|
+
return nil, err
|
|
334
|
+
}
|
|
335
|
+
req.Header.Set("Content-Type", "application/json")
|
|
336
|
+
req.Header.Set("Accept", "application/json")
|
|
337
|
+
|
|
338
|
+
client := &http.Client{Timeout: 20 * time.Second}
|
|
339
|
+
resp, err := client.Do(req)
|
|
340
|
+
if err != nil {
|
|
341
|
+
return nil, err
|
|
342
|
+
}
|
|
343
|
+
defer resp.Body.Close()
|
|
344
|
+
|
|
345
|
+
data, err := io.ReadAll(resp.Body)
|
|
346
|
+
if err != nil {
|
|
347
|
+
return nil, err
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
351
|
+
return nil, fmt.Errorf("device auth init failed (status %d): %s", resp.StatusCode, strings.TrimSpace(string(data)))
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
var payload struct {
|
|
355
|
+
DeviceAuthID string `json:"device_auth_id"`
|
|
356
|
+
UserCode string `json:"user_code"`
|
|
357
|
+
Interval string `json:"interval"`
|
|
358
|
+
}
|
|
359
|
+
if err := json.Unmarshal(data, &payload); err != nil {
|
|
360
|
+
return nil, fmt.Errorf("failed to parse device auth response: %w", err)
|
|
361
|
+
}
|
|
362
|
+
if strings.TrimSpace(payload.DeviceAuthID) == "" || strings.TrimSpace(payload.UserCode) == "" {
|
|
363
|
+
return nil, fmt.Errorf("device auth response missing required fields")
|
|
364
|
+
}
|
|
365
|
+
|
|
366
|
+
seconds := 5
|
|
367
|
+
if parsed, convErr := strconv.Atoi(strings.TrimSpace(payload.Interval)); convErr == nil && parsed > 0 {
|
|
368
|
+
seconds = parsed
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
return &deviceAuthResponse{
|
|
372
|
+
DeviceAuthID: strings.TrimSpace(payload.DeviceAuthID),
|
|
373
|
+
UserCode: strings.TrimSpace(payload.UserCode),
|
|
374
|
+
Interval: time.Duration(seconds) * time.Second,
|
|
375
|
+
}, nil
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
func pollForDeviceAuthorization(device *deviceAuthResponse) (string, string, error) {
|
|
379
|
+
if device == nil {
|
|
380
|
+
return "", "", fmt.Errorf("device auth context is nil")
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
client := &http.Client{Timeout: 20 * time.Second}
|
|
384
|
+
ticker := time.NewTicker(device.Interval + headlessPollingSafetyMargin)
|
|
385
|
+
defer ticker.Stop()
|
|
386
|
+
timeout := time.After(10 * time.Minute)
|
|
387
|
+
|
|
388
|
+
bodyPayload, _ := json.Marshal(map[string]string{
|
|
389
|
+
"device_auth_id": device.DeviceAuthID,
|
|
390
|
+
"user_code": device.UserCode,
|
|
391
|
+
})
|
|
392
|
+
|
|
393
|
+
for {
|
|
394
|
+
select {
|
|
395
|
+
case <-timeout:
|
|
396
|
+
return "", "", fmt.Errorf("headless oauth timed out")
|
|
397
|
+
case <-ticker.C:
|
|
398
|
+
req, err := http.NewRequest(http.MethodPost, deviceTokenURL, strings.NewReader(string(bodyPayload)))
|
|
399
|
+
if err != nil {
|
|
400
|
+
return "", "", err
|
|
401
|
+
}
|
|
402
|
+
req.Header.Set("Content-Type", "application/json")
|
|
403
|
+
req.Header.Set("Accept", "application/json")
|
|
404
|
+
|
|
405
|
+
resp, err := client.Do(req)
|
|
406
|
+
if err != nil {
|
|
407
|
+
return "", "", err
|
|
408
|
+
}
|
|
409
|
+
|
|
410
|
+
data, readErr := io.ReadAll(resp.Body)
|
|
411
|
+
_ = resp.Body.Close()
|
|
412
|
+
if readErr != nil {
|
|
413
|
+
return "", "", readErr
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound {
|
|
417
|
+
continue
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
421
|
+
return "", "", fmt.Errorf("device auth polling failed (status %d): %s", resp.StatusCode, strings.TrimSpace(string(data)))
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
var payload struct {
|
|
425
|
+
AuthorizationCode string `json:"authorization_code"`
|
|
426
|
+
CodeVerifier string `json:"code_verifier"`
|
|
427
|
+
}
|
|
428
|
+
if err := json.Unmarshal(data, &payload); err != nil {
|
|
429
|
+
return "", "", fmt.Errorf("failed to parse device auth poll response: %w", err)
|
|
430
|
+
}
|
|
431
|
+
|
|
432
|
+
code := strings.TrimSpace(payload.AuthorizationCode)
|
|
433
|
+
verifier := strings.TrimSpace(payload.CodeVerifier)
|
|
434
|
+
if code == "" || verifier == "" {
|
|
435
|
+
return "", "", fmt.Errorf("device auth poll response missing authorization code")
|
|
436
|
+
}
|
|
437
|
+
return code, verifier, nil
|
|
438
|
+
}
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
// openBrowser starts the platform's URL opener for targetURL without waiting
// for it to finish: rundll32 on Windows, open on macOS, and xdg-open on every
// other GOOS. The returned error is the one from starting the process.
func openBrowser(targetURL string) error {
	program, args := "xdg-open", []string{targetURL}
	switch runtime.GOOS {
	case "windows":
		program, args = "rundll32", []string{"url.dll,FileProtocolHandler", targetURL}
	case "darwin":
		program = "open"
	}
	return exec.Command(program, args...).Start()
}
|
|
454
|
+
|
|
455
|
+
// generateRandomString returns a string of exactly length characters drawn
// from the unpadded URL-safe base64 alphabet, derived from cryptographically
// secure random bytes.
//
// It panics if the system random source fails. The result is used as OAuth
// state and as a PKCE verifier, so continuing with an unfilled (all-zero)
// buffer would silently produce a predictable secret.
func generateRandomString(length int) string {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Errorf("auth: reading random bytes: %w", err))
	}
	// Encoding length bytes yields more than length characters; keep the
	// first length of them.
	return base64.RawURLEncoding.EncodeToString(buf)[:length]
}
|
|
460
|
+
|
|
461
|
+
// generateCodeChallenge derives the PKCE S256 challenge for verifier: the
// SHA-256 digest of the verifier, encoded as unpadded URL-safe base64.
func generateCodeChallenge(verifier string) string {
	digest := sha256.Sum256([]byte(verifier))
	encoder := base64.RawURLEncoding
	return encoder.EncodeToString(digest[:])
}
|
|
465
|
+
|
|
466
|
+
// decodeJWTClaims returns the claims object from the payload segment of a
// three-part, dot-separated token. The signature is not verified.
//
// The payload is decoded as unpadded URL-safe base64 first and as padded
// URL-safe base64 second. nil is returned when the token does not have
// exactly three segments, when neither decoding succeeds, or when the
// payload is not a JSON object.
func decodeJWTClaims(token string) map[string]any {
	segments := strings.Split(strings.TrimSpace(token), ".")
	if len(segments) != 3 {
		return nil
	}

	var (
		raw []byte
		err error
	)
	for _, encoding := range []*base64.Encoding{base64.RawURLEncoding, base64.URLEncoding} {
		if raw, err = encoding.DecodeString(segments[1]); err == nil {
			break
		}
	}
	if err != nil {
		return nil
	}

	claims := map[string]any{}
	if json.Unmarshal(raw, &claims) != nil {
		return nil
	}
	return claims
}
|
|
488
|
+
|
|
489
|
+
// extractAccountID returns the account identifier found in claims, with
// surrounding whitespace removed, or "" when none is present.
//
// Locations are tried in order and the first non-blank string wins:
//
//  1. the top-level "chatgpt_account_id" claim,
//  2. "chatgpt_account_id" inside the "https://api.openai.com/auth" object,
//  3. the "id" of the first entry in the "organizations" array.
//
// A value that is missing, not a string, or blank after trimming does not
// stop the search, so an empty claim cannot hide an ID in a later location.
func extractAccountID(claims map[string]any) string {
	// text returns the trimmed string stored under key, or "" when the key
	// is absent or holds a non-string. Reading a nil map is safe.
	text := func(m map[string]any, key string) string {
		value, _ := m[key].(string)
		return strings.TrimSpace(value)
	}

	if id := text(claims, "chatgpt_account_id"); id != "" {
		return id
	}

	if nested, ok := claims["https://api.openai.com/auth"].(map[string]any); ok {
		if id := text(nested, "chatgpt_account_id"); id != "" {
			return id
		}
	}

	if organizations, ok := claims["organizations"].([]any); ok && len(organizations) > 0 {
		if org, ok := organizations[0].(map[string]any); ok {
			return text(org, "id")
		}
	}

	return ""
}
|
|
514
|
+
|
|
515
|
+
// extractEmail returns the top-level "email" claim with surrounding
// whitespace removed, or "" when claims is nil or the claim is missing or
// not a string.
func extractEmail(claims map[string]any) string {
	// Indexing a nil map yields the zero value, so no separate nil check is
	// needed before the type assertion.
	email, isString := claims["email"].(string)
	if !isString {
		return ""
	}
	return strings.TrimSpace(email)
}
|