Skip to content

Commit 6e2f6fb

Browse files
authored
feat: password-less database login (#3885)
1 parent bcbc267 commit 6e2f6fb

File tree

8 files changed

+132
-53
lines changed

8 files changed

+132
-53
lines changed

cmd/root.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ var (
111111
}
112112
}
113113
}
114-
if err := flags.ParseDatabaseConfig(cmd.Flags(), fsys); err != nil {
114+
if err := flags.ParseDatabaseConfig(ctx, cmd.Flags(), fsys); err != nil {
115115
return err
116116
}
117117
// Prepare context

internal/bootstrap/bootstrap.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func Run(ctx context.Context, starter StarterTemplate, fsys afero.Fs, options ..
113113
return err
114114
}
115115
// 6. Push migrations
116-
config := flags.NewDbConfigWithPassword(flags.ProjectRef)
116+
config := flags.NewDbConfigWithPassword(ctx, flags.ProjectRef)
117117
if err := writeDotEnv(keys, config, fsys); err != nil {
118118
fmt.Fprintln(os.Stderr, "Failed to create .env file:", err)
119119
}

internal/link/link.go

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"github.com/spf13/afero"
1515
"github.com/spf13/viper"
1616
"github.com/supabase/cli/internal/utils"
17-
"github.com/supabase/cli/internal/utils/credentials"
1817
"github.com/supabase/cli/internal/utils/flags"
1918
"github.com/supabase/cli/internal/utils/tenant"
2019
"github.com/supabase/cli/pkg/api"
@@ -43,15 +42,9 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(
4342
LinkServices(ctx, projectRef, keys.Anon, fsys)
4443

4544
// 2. Check database connection
46-
config := flags.GetDbConfigOptionalPassword(projectRef)
47-
if len(config.Password) > 0 {
48-
if err := linkDatabase(ctx, config, fsys, options...); err != nil {
49-
return err
50-
}
51-
// Save database password
52-
if err := credentials.StoreProvider.Set(projectRef, config.Password); err != nil {
53-
fmt.Fprintln(os.Stderr, "Failed to save database password:", err)
54-
}
45+
config := flags.NewDbConfigWithPassword(ctx, projectRef)
46+
if err := linkDatabase(ctx, config, fsys, options...); err != nil {
47+
return err
5548
}
5649

5750
// 3. Save project ref

internal/link/link_test.go

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/supabase/cli/pkg/api"
2222
"github.com/supabase/cli/pkg/migration"
2323
"github.com/supabase/cli/pkg/pgtest"
24+
"github.com/supabase/cli/pkg/pgxv5"
2425
"github.com/zalando/go-keyring"
2526
)
2627

@@ -47,7 +48,9 @@ func TestLinkCommand(t *testing.T) {
4748
// Setup mock postgres
4849
conn := pgtest.NewConn()
4950
defer conn.Close(t)
50-
conn.Query(GET_LATEST_STORAGE_MIGRATION).
51+
conn.Query(pgxv5.SET_SESSION_ROLE).
52+
Reply("SET ROLE").
53+
Query(GET_LATEST_STORAGE_MIGRATION).
5154
Reply("SELECT 1", []interface{}{"custom-metadata"})
5255
helper.MockMigrationHistory(conn)
5356
helper.MockSeedHistory(conn)
@@ -92,6 +95,9 @@ func TestLinkCommand(t *testing.T) {
9295
Get("/v1/projects/" + project + "/network-restrictions").
9396
Reply(200).
9497
JSON(api.NetworkRestrictionsResponse{})
98+
gock.New(utils.DefaultApiHost).
99+
Post("/v1/projects/" + project + "/database/query").
100+
Reply(http.StatusCreated)
95101
// Link versions
96102
auth := tenant.HealthResponse{Version: "v2.74.2"}
97103
gock.New("https://" + utils.GetSupabaseHost(project)).
@@ -158,8 +164,11 @@ func TestLinkCommand(t *testing.T) {
158164
ReplyError(errors.New("network error"))
159165
gock.New(utils.DefaultApiHost).
160166
Get("/v1/projects/" + project + "/network-restrictions").
161-
Reply(200).
167+
Reply(http.StatusOK).
162168
JSON(api.NetworkRestrictionsResponse{})
169+
gock.New(utils.DefaultApiHost).
170+
Post("/v1/projects/" + project + "/database/query").
171+
Reply(http.StatusServiceUnavailable)
163172
// Link versions
164173
gock.New("https://" + utils.GetSupabaseHost(project)).
165174
Get("/auth/v1/health").
@@ -181,6 +190,15 @@ func TestLinkCommand(t *testing.T) {
181190
t.Run("throws error on write failure", func(t *testing.T) {
182191
// Setup in-memory fs
183192
fsys := afero.NewReadOnlyFs(afero.NewMemMapFs())
193+
// Setup mock postgres
194+
conn := pgtest.NewConn()
195+
defer conn.Close(t)
196+
conn.Query(pgxv5.SET_SESSION_ROLE).
197+
Reply("SET ROLE").
198+
Query(GET_LATEST_STORAGE_MIGRATION).
199+
Reply("SELECT 1", []interface{}{"custom-metadata"})
200+
helper.MockMigrationHistory(conn)
201+
helper.MockSeedHistory(conn)
184202
// Flush pending mocks after test execution
185203
defer gock.OffAll()
186204
// Mock project status
@@ -212,8 +230,11 @@ func TestLinkCommand(t *testing.T) {
212230
ReplyError(errors.New("network error"))
213231
gock.New(utils.DefaultApiHost).
214232
Get("/v1/projects/" + project + "/network-restrictions").
215-
Reply(200).
233+
Reply(http.StatusOK).
216234
JSON(api.NetworkRestrictionsResponse{})
235+
gock.New(utils.DefaultApiHost).
236+
Post("/v1/projects/" + project + "/database/query").
237+
Reply(http.StatusCreated)
217238
// Link versions
218239
gock.New("https://" + utils.GetSupabaseHost(project)).
219240
Get("/auth/v1/health").
@@ -225,7 +246,7 @@ func TestLinkCommand(t *testing.T) {
225246
Get("/v1/projects").
226247
ReplyError(errors.New("network error"))
227248
// Run test
228-
err := Run(context.Background(), project, fsys)
249+
err := Run(context.Background(), project, fsys, conn.Intercept)
229250
// Check error
230251
assert.ErrorContains(t, err, "operation not permitted")
231252
assert.Empty(t, apitest.ListUnmatchedRequests())

internal/utils/flags/db_url.go

Lines changed: 67 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package flags
22

33
import (
4+
"bytes"
5+
"context"
46
"crypto/rand"
7+
_ "embed"
58
"fmt"
69
"math/big"
10+
"net/http"
711
"os"
812
"strings"
13+
"text/template"
914

1015
"github.com/go-errors/errors"
1116
"github.com/jackc/pgconn"
@@ -14,7 +19,9 @@ import (
1419
"github.com/spf13/viper"
1520
"github.com/supabase/cli/internal/utils"
1621
"github.com/supabase/cli/internal/utils/credentials"
22+
"github.com/supabase/cli/pkg/api"
1723
"github.com/supabase/cli/pkg/config"
24+
"github.com/supabase/cli/pkg/pgxv5"
1825
)
1926

2027
type connection int
@@ -29,7 +36,7 @@ const (
2936

3037
var DbConfig pgconn.Config
3138

32-
func ParseDatabaseConfig(flagSet *pflag.FlagSet, fsys afero.Fs) error {
39+
func ParseDatabaseConfig(ctx context.Context, flagSet *pflag.FlagSet, fsys afero.Fs) error {
3340
// Changed flags take precedence over default values
3441
var connType connection
3542
if flag := flagSet.Lookup("db-url"); flag != nil && flag.Changed {
@@ -77,7 +84,7 @@ func ParseDatabaseConfig(flagSet *pflag.FlagSet, fsys afero.Fs) error {
7784
if err := LoadConfig(fsys); err != nil {
7885
return err
7986
}
80-
DbConfig = NewDbConfigWithPassword(ProjectRef)
87+
DbConfig = NewDbConfigWithPassword(ctx, ProjectRef)
8188
case proxy:
8289
token, err := utils.LoadAccessTokenFS(fsys)
8390
if err != nil {
@@ -95,23 +102,71 @@ func ParseDatabaseConfig(flagSet *pflag.FlagSet, fsys afero.Fs) error {
95102
return nil
96103
}
97104

98-
func NewDbConfigWithPassword(projectRef string) pgconn.Config {
99-
config := getDbConfig(projectRef)
100-
config.Password = getPassword(projectRef)
101-
return config
105+
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
106+
107+
func RandomString(size int) (string, error) {
108+
data := make([]byte, size)
109+
_, err := rand.Read(data)
110+
if err != nil {
111+
return "", errors.Errorf("failed to read random: %w", err)
112+
}
113+
for i := range data {
114+
n := int(data[i]) % len(letters)
115+
data[i] = letters[n]
116+
}
117+
return string(data), nil
102118
}
103119

104-
func getPassword(projectRef string) string {
105-
if password := viper.GetString("DB_PASSWORD"); len(password) > 0 {
106-
return password
120+
func NewDbConfigWithPassword(ctx context.Context, projectRef string) pgconn.Config {
121+
config := getDbConfig(projectRef)
122+
config.Password = viper.GetString("DB_PASSWORD")
123+
if len(config.Password) > 0 {
124+
return config
107125
}
108-
if password, err := credentials.StoreProvider.Get(projectRef); err == nil {
109-
return password
126+
var err error
127+
if config.Password, err = RandomString(32); err == nil {
128+
newRole := pgconn.Config{
129+
User: pgxv5.CLI_LOGIN_ROLE,
130+
Password: config.Password,
131+
}
132+
if err := initLoginRole(ctx, projectRef, newRole); err == nil {
133+
// Special handling for pooler username
134+
if suffix := "." + projectRef; strings.HasSuffix(config.User, suffix) {
135+
newRole.User += suffix
136+
}
137+
config.User = newRole.User
138+
return config
139+
}
140+
}
141+
if config.Password, err = credentials.StoreProvider.Get(projectRef); err == nil {
142+
return config
110143
}
111144
resetUrl := fmt.Sprintf("%s/project/%s/settings/database", utils.GetSupabaseDashboardURL(), projectRef)
112145
fmt.Fprintln(os.Stderr, "Forgot your password? Reset it from the Dashboard:", utils.Bold(resetUrl))
113146
fmt.Fprint(os.Stderr, "Enter your database password: ")
114-
return credentials.PromptMasked(os.Stdin)
147+
config.Password = credentials.PromptMasked(os.Stdin)
148+
return config
149+
}
150+
151+
var (
152+
//go:embed queries/role.sql
153+
initRoleEmbed string
154+
initRoleTemplate = template.Must(template.New("initRole").Parse(initRoleEmbed))
155+
)
156+
157+
func initLoginRole(ctx context.Context, projectRef string, config pgconn.Config) error {
158+
fmt.Fprintf(os.Stderr, "Initialising %s role...\n", config.User)
159+
var initRoleBuf bytes.Buffer
160+
if err := initRoleTemplate.Option("missingkey=error").Execute(&initRoleBuf, config); err != nil {
161+
return errors.Errorf("failed to exec template: %w", err)
162+
}
163+
body := api.V1RunQueryBody{Query: initRoleBuf.String()}
164+
if resp, err := utils.GetSupabase().V1RunAQueryWithResponse(ctx, projectRef, body); err != nil {
165+
return errors.Errorf("failed to initialise login role: %w", err)
166+
} else if resp.StatusCode() != http.StatusCreated {
167+
return errors.Errorf("unexpected query status %d: %s", resp.StatusCode(), string(resp.Body))
168+
}
169+
return nil
115170
}
116171

117172
const PASSWORD_LENGTH = 16
@@ -148,13 +203,3 @@ func getDbConfig(projectRef string) pgconn.Config {
148203
Database: "postgres",
149204
}
150205
}
151-
152-
func GetDbConfigOptionalPassword(projectRef string) pgconn.Config {
153-
config := getDbConfig(projectRef)
154-
config.Password = viper.GetString("DB_PASSWORD")
155-
if config.Password == "" {
156-
fmt.Fprint(os.Stderr, "Enter your database password (or leave blank to skip): ")
157-
config.Password = credentials.PromptMasked(os.Stdin)
158-
}
159-
return config
160-
}

internal/utils/flags/db_url_test.go

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
package flags
22

33
import (
4+
"context"
45
"os"
56
"testing"
67

78
"github.com/spf13/afero"
89
"github.com/spf13/pflag"
9-
"github.com/spf13/viper"
1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
1212
"github.com/supabase/cli/internal/testing/apitest"
1313
"github.com/supabase/cli/internal/utils"
1414
)
1515

1616
func TestParseDatabaseConfig(t *testing.T) {
17+
// Setup valid access token
18+
token := apitest.RandomAccessToken(t)
19+
t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
20+
1721
t.Run("parses direct connection from db-url flag", func(t *testing.T) {
1822
flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
1923
flagSet.String("db-url", "postgres://postgres:password@localhost:5432/postgres", "")
@@ -22,7 +26,7 @@ func TestParseDatabaseConfig(t *testing.T) {
2226

2327
fsys := afero.NewMemMapFs()
2428

25-
err = ParseDatabaseConfig(flagSet, fsys)
29+
err = ParseDatabaseConfig(context.Background(), flagSet, fsys)
2630

2731
assert.NoError(t, err)
2832
assert.Equal(t, "db.example.com", DbConfig.Host)
@@ -44,7 +48,7 @@ func TestParseDatabaseConfig(t *testing.T) {
4448
utils.Config.Db.Port = 54322
4549
utils.Config.Db.Password = "local-password"
4650

47-
err = ParseDatabaseConfig(flagSet, fsys)
51+
err = ParseDatabaseConfig(context.Background(), flagSet, fsys)
4852

4953
assert.NoError(t, err)
5054
assert.Equal(t, "localhost", DbConfig.Host)
@@ -66,7 +70,7 @@ func TestParseDatabaseConfig(t *testing.T) {
6670
err = afero.WriteFile(fsys, utils.ProjectRefPath, []byte(project), 0644)
6771
require.NoError(t, err)
6872

69-
err = ParseDatabaseConfig(flagSet, fsys)
73+
err = ParseDatabaseConfig(context.Background(), flagSet, fsys)
7074

7175
assert.NoError(t, err)
7276
assert.Equal(t, utils.GetSupabaseDbHost(project), DbConfig.Host)
@@ -105,14 +109,3 @@ func TestPromptPassword(t *testing.T) {
105109
assert.NotEqual(t, "", password)
106110
})
107111
}
108-
109-
func TestGetDbConfigOptionalPassword(t *testing.T) {
110-
t.Run("uses environment variable when available", func(t *testing.T) {
111-
viper.Set("DB_PASSWORD", "env-password")
112-
projectRef := apitest.RandomProjectRef()
113-
114-
config := GetDbConfigOptionalPassword(projectRef)
115-
116-
assert.Equal(t, "env-password", config.Password)
117-
})
118-
}

internal/utils/flags/queries/role.sql

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
do $func$
2+
begin
3+
if not exists (
4+
select 1
5+
from pg_roles
6+
where rolname = '{{ .User }}'
7+
)
8+
then
9+
create role "{{ .User }}" noinherit login noreplication in role postgres;
10+
end if;
11+
execute format(
12+
$$alter role "{{ .User }}" with password '{{ .Password }}' valid until %L$$,
13+
now() + interval '5 minutes'
14+
);
15+
end
16+
$func$ language plpgsql;

pkg/pgxv5/connect.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,18 @@ import (
44
"context"
55
"fmt"
66
"os"
7+
"strings"
78

89
"github.com/go-errors/errors"
910
"github.com/jackc/pgconn"
1011
"github.com/jackc/pgx/v4"
1112
)
1213

14+
const (
15+
CLI_LOGIN_ROLE = "cli_login_postgres"
16+
SET_SESSION_ROLE = "SET SESSION ROLE postgres"
17+
)
18+
1319
// Extends pgx.Connect with support for programmatically overriding parsed config
1420
func Connect(ctx context.Context, connString string, options ...func(*pgx.ConnConfig)) (*pgx.Conn, error) {
1521
// Parse connection url
@@ -20,6 +26,11 @@ func Connect(ctx context.Context, connString string, options ...func(*pgx.ConnCo
2026
config.OnNotice = func(pc *pgconn.PgConn, n *pgconn.Notice) {
2127
fmt.Fprintf(os.Stderr, "%s (%s): %s\n", n.Severity, n.Code, n.Message)
2228
}
29+
if strings.HasPrefix(config.User, CLI_LOGIN_ROLE) {
30+
config.AfterConnect = func(ctx context.Context, pgconn *pgconn.PgConn) error {
31+
return pgconn.Exec(ctx, SET_SESSION_ROLE).Close()
32+
}
33+
}
2334
// Apply config overrides
2435
for _, op := range options {
2536
op(config)

0 commit comments

Comments
 (0)