diff --git a/cmd/src/login_oauth.go b/cmd/src/login_oauth.go index 1e404d6447..a1005a1684 100644 --- a/cmd/src/login_oauth.go +++ b/cmd/src/login_oauth.go @@ -15,16 +15,40 @@ import ( "github.com/sourcegraph/src-cli/internal/oauth" ) -var loadStoredOAuthToken = oauth.LoadToken +var ( + loadStoredOAuthToken = oauth.LoadToken + storeOAuthToken = oauth.StoreToken +) func runOAuthLogin(ctx context.Context, p loginParams) error { - client, err := oauthLoginClient(ctx, p) + client, loadedFromStore, err := oauthLoginClient(ctx, p) if err != nil { printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL)) return cmderrors.ExitCode1 } + if loadedFromStore { + username, validateErr := currentUsername(ctx, client) + if validateErr == nil && username != "" { + printAuthenticatedUser(p.out, username, p.cfg.endpointURL) + fmt.Fprintln(p.out) + fmt.Fprint(p.out, "✔︎ Authenticated with OAuth credentials") + fmt.Fprintln(p.out) + return nil + } + + fmt.Fprintln(p.out) + fmt.Fprintln(p.out, "⚠️ Warning: Stored OAuth credentials could not be verified. Starting a new OAuth device flow.") + + client, err = newOAuthLoginClient(ctx, p) + if err != nil { + printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err)) + fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL)) + return cmderrors.ExitCode1 + } + } + if err := validateCurrentUser(ctx, client, p.out, p.cfg.endpointURL); err != nil { return err } @@ -38,18 +62,23 @@ func runOAuthLogin(ctx context.Context, p loginParams) error { // oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token // and use it if one is present. // If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage. -func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, error) { - // if we have a stored token, used it. Otherwise run the device flow +func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, bool, error) { + // if we have a stored token, use it. Otherwise run the device flow if token, err := loadStoredOAuthToken(ctx, p.cfg.endpointURL); err == nil { - return newOAuthAPIClient(p, token), nil + return newOAuthAPIClient(p, token), true, nil } + client, err := newOAuthLoginClient(ctx, p) + return client, false, err +} + +func newOAuthLoginClient(ctx context.Context, p loginParams) (api.Client, error) { token, err := runOAuthDeviceFlow(ctx, p.cfg.endpointURL, p.out, p.oauthClient) if err != nil { return nil, err } - if err := oauth.StoreToken(ctx, token); err != nil { + if err := storeOAuthToken(ctx, token); err != nil { fmt.Fprintln(p.out) fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err) } diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 5dba8b464b..ac1e73c609 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -137,11 +137,80 @@ func TestLogin(t *testing.T) { t.Errorf("got output %q, want %q", gotOut, wantOut) } }) + + t.Run("invalid stored oauth token restarts device flow", func(t *testing.T) { + var authHeaders []string + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeaders = append(authHeaders, r.Header.Get("Authorization")) + if r.Header.Get("Authorization") != "Bearer new-oauth-token" { + http.Error(w, "", http.StatusUnauthorized) + return + } + fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`) + })) + defer s.Close() + + restoreStoredOAuthLoader(t, func(_ context.Context, _ *url.URL) (*oauth.Token, error) { + return &oauth.Token{ + Endpoint: s.URL, + ClientID: oauth.DefaultClientID, + AccessToken: "old-oauth-token", + ExpiresAt: time.Now().Add(time.Hour), + }, nil + }) + restoreOAuthTokenStore(t, func(context.Context, *oauth.Token) error { return nil }) + + u, _ := url.ParseRequestURI(s.URL) + startCalled := false + pollCalled := false + var out bytes.Buffer + err := loginCmd(context.Background(), loginParams{ + cfg: &config{endpointURL: u}, + client: (&config{endpointURL: u}).apiClient(nil, io.Discard), + out: &out, + oauthClient: fakeOAuthClient{ + startCalled: &startCalled, + deviceResp: &oauth.DeviceAuthResponse{ + DeviceCode: "device-code", + ExpiresIn: 60, + }, + pollCalled: &pollCalled, + pollResp: &oauth.TokenResponse{ + AccessToken: "new-oauth-token", + ExpiresIn: 3600, + TokenType: "Bearer", + }, + }, + }) + if err != nil { + t.Fatal(err) + } + if !startCalled || !pollCalled { + t.Fatal("expected invalid stored oauth token to restart device flow") + } + if len(authHeaders) != 2 || authHeaders[0] != "Bearer old-oauth-token" || authHeaders[1] != "Bearer new-oauth-token" { + t.Fatalf("Authorization headers = %q, want old token then new token", authHeaders) + } + gotOut := out.String() + for _, want := range []string{ + "⚠️ Warning: Stored OAuth credentials could not be verified. Starting a new OAuth device flow.", + "Waiting for authorization... DONE", + "✔︎ Authenticated as alice on " + s.URL, + "✔︎ Authenticated with OAuth credentials", + } { + if !strings.Contains(gotOut, want) { + t.Errorf("got output %q, want it to contain %q", gotOut, want) + } + } + }) } type fakeOAuthClient struct { startErr error startCalled *bool + deviceResp *oauth.DeviceAuthResponse + pollCalled *bool + pollResp *oauth.TokenResponse } func (f fakeOAuthClient) ClientID() string { @@ -156,10 +225,22 @@ func (f fakeOAuthClient) Start(context.Context, *url.URL, []string) (*oauth.Devi if f.startCalled != nil { *f.startCalled = true } - return nil, f.startErr + if f.startErr != nil { + return nil, f.startErr + } + if f.deviceResp != nil { + return f.deviceResp, nil + } + return nil, fmt.Errorf("unexpected call to Start") } func (f fakeOAuthClient) Poll(context.Context, *url.URL, string, time.Duration, int) (*oauth.TokenResponse, error) { + if f.pollCalled != nil { + *f.pollCalled = true + } + if f.pollResp != nil { + return f.pollResp, nil + } return nil, fmt.Errorf("unexpected call to Poll") } @@ -242,3 +323,13 @@ func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, *url.UR loadStoredOAuthToken = prev }) } + +func restoreOAuthTokenStore(t *testing.T, store func(context.Context, *oauth.Token) error) { + t.Helper() + + prev := storeOAuthToken + storeOAuthToken = store + t.Cleanup(func() { + storeOAuthToken = prev + }) +} diff --git a/cmd/src/login_validate.go b/cmd/src/login_validate.go index 095ea7ab22..384e0c534a 100644 --- a/cmd/src/login_validate.go +++ b/cmd/src/login_validate.go @@ -16,11 +16,8 @@ func runValidatedLogin(ctx context.Context, p loginParams) error { } func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpointURL *url.URL) error { - query := `query CurrentUser { currentUser { username } }` - var result struct { - CurrentUser *struct{ Username string } - } - if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil { + username, err := currentUsername(ctx, client) + if err != nil { if strings.HasPrefix(err.Error(), "error: 401 Unauthorized") || strings.HasPrefix(err.Error(), "error: 403 Forbidden") { printLoginProblem(out, "Invalid access token.") } else { @@ -31,14 +28,32 @@ func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, return cmderrors.ExitCode1 } - if result.CurrentUser == nil { + if username == "" { // This should never happen; we verified there is an access token, so there should always be // a user. printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpointURL)) return cmderrors.ExitCode1 } + printAuthenticatedUser(out, username, endpointURL) + return nil +} + +func printAuthenticatedUser(out io.Writer, username string, endpointURL *url.URL) { fmt.Fprintln(out) - fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointURL) + fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", username, endpointURL) fmt.Fprintln(out) - return nil +} + +func currentUsername(ctx context.Context, client api.Client) (string, error) { + query := `query CurrentUser { currentUser { username } }` + var result struct { + CurrentUser *struct{ Username string } + } + if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil { + return "", err + } + if result.CurrentUser == nil { + return "", nil + } + return result.CurrentUser.Username, nil }