Skip to content

Commit 97cbff1

Browse files
committed
drpc: add changes for auth interceptor changes
1 parent 23bcd50 commit 97cbff1

File tree

5 files changed

+61
-0
lines changed

5 files changed

+61
-0
lines changed

drpcclient/clientconn.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55

66
"storj.io/drpc"
7+
"storj.io/drpc/drpcmetadata"
78
)
89

910
// ClientConn represents a DRPC client connection, with support for configuring the
@@ -33,6 +34,9 @@ func finalInvoker(ctx context.Context, rpc string, enc drpc.Encoding, in, out dr
3334
}
3435

3536
func (c *ClientConn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
37+
if c.dopts.perRPCMetadata != nil {
38+
ctx = drpcmetadata.AddPairs(ctx, c.dopts.perRPCMetadata)
39+
}
3640
if c.dopts.unaryInt != nil {
3741
return c.dopts.unaryInt(ctx, rpc, enc, in, out, c, finalInvoker)
3842
}
@@ -45,6 +49,9 @@ func finalStreamer(ctx context.Context, rpc string, enc drpc.Encoding, cc *Clien
4549
}
4650

4751
func (c *ClientConn) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
52+
if c.dopts.perRPCMetadata != nil {
53+
ctx = drpcmetadata.AddPairs(ctx, c.dopts.perRPCMetadata)
54+
}
4855
if c.dopts.streamInt != nil {
4956
return c.dopts.streamInt(ctx, rpc, enc, c, finalStreamer)
5057
}

drpcclient/dialoptions.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ type dialOptions struct {
88

99
unaryInts []UnaryClientInterceptor
1010
streamInts []StreamClientInterceptor
11+
12+
perRPCMetadata map[string]string
1113
}
1214

1315
// DialOption configures how we set up the client connection.
@@ -32,3 +34,9 @@ func WithChainStreamInterceptor(ints ...StreamClientInterceptor) DialOption {
3234
opt.streamInts = append(opt.streamInts, ints...)
3335
}
3436
}
37+
38+
func WithPerRPCMetadata(metadata map[string]string) DialOption {
39+
return func(opt *dialOptions) {
40+
opt.perRPCMetadata = metadata
41+
}
42+
}

drpcctx/tlscert.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright (C) 2025 Storj Labs, Inc.
2+
// See LICENSE for copying information.
3+
4+
package drpcctx
5+
6+
import (
7+
"context"
8+
"crypto/x509"
9+
)
10+
11+
// TLSPeerCertKey is used to store TLS info in the context.
12+
type TLSPeerCertKey struct{}
13+
14+
// WithPeerCertificate associates the TLS info with the context.
15+
func WithPeerCertificate(ctx context.Context, certificate *x509.Certificate) context.Context {
16+
return context.WithValue(ctx, TLSPeerCertKey{}, certificate)
17+
}
18+
19+
// GetPeerCertificate returns the TLS info associated with the context and a bool
20+
// indicating if they existed.
21+
func GetPeerCertificate(ctx context.Context) (*x509.Certificate, bool) {
22+
tlsInfo, ok := ctx.Value(TLSPeerCertKey{}).(*x509.Certificate)
23+
return tlsInfo, ok
24+
}

drpcmetadata/metadata.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,12 @@ func Get(ctx context.Context) (map[string]string, bool) {
8888
metadata, ok := ctx.Value(metadataKey{}).(map[string]string)
8989
return metadata, ok
9090
}
91+
92+
func GetValue(ctx context.Context, key string) (string, bool) {
93+
metadata, ok := Get(ctx)
94+
if !ok {
95+
return "", false
96+
}
97+
val, ok := metadata[key]
98+
return val, ok
99+
}

drpcserver/server.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package drpcserver
55

66
import (
77
"context"
8+
"crypto/tls"
89
"net"
910
"sync"
1011
"time"
@@ -92,6 +93,18 @@ func (s *Server) getStats(rpc string) *drpcstats.Stats {
9293

9394
// ServeOne serves a single set of rpcs on the provided transport.
9495
func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) {
96+
// Check if the transport is a TLS connection
97+
if tlsConn, ok := tr.(*tls.Conn); ok {
98+
err := tlsConn.Handshake()
99+
if err != nil {
100+
return err
101+
}
102+
state := tlsConn.ConnectionState()
103+
if len(state.PeerCertificates) > 0 {
104+
ctx = drpcctx.WithPeerCertificate(ctx, state.PeerCertificates[0])
105+
}
106+
}
107+
95108
man := drpcmanager.NewWithOptions(tr, s.opts.Manager)
96109
defer func() { err = errs.Combine(err, man.Close()) }()
97110

0 commit comments

Comments
 (0)