Skip to content

Commit a933acd

Browse files
committed
Enforce discard limits on readers
This enforces limits on discard to avoid unbounded reads. Where resources were already exhausted no further reads are done and discards have been removed. These discards were an optimization to reuse connections. When a stream is partially read all subsequent reads will now return EOF errors to avoid reading in a corrupted state. Signed-off-by: Edward McFarlane <emcfarlane@buf.build>
1 parent 145b279 commit a933acd

File tree

6 files changed

+30
-48
lines changed

6 files changed

+30
-48
lines changed

compression.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,13 @@ func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readM
9696
}
9797
return errorf(CodeInvalidArgument, "decompress: %w", err)
9898
}
99-
if readMaxBytes > 0 && bytesRead > readMaxBytes {
100-
discardedBytes, err := io.Copy(io.Discard, decompressor)
101-
_ = c.putDecompressor(decompressor)
102-
if err != nil {
103-
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", readMaxBytes, err)
104-
}
105-
return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, readMaxBytes)
106-
}
10799
if err := c.putDecompressor(decompressor); err != nil {
108100
return errorf(CodeUnknown, "recycle decompressor: %w", err)
109101
}
102+
if readMaxBytes > 0 && bytesRead > readMaxBytes {
103+
// Resource is exhausted, fail fast without reading more data from the reader.
104+
return errorf(CodeResourceExhausted, "decompressed message size is larger than configured max %d", readMaxBytes)
105+
}
110106
return nil
111107
}
112108

connect_ext_test.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,6 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
11971197
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
11981198
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
11991199
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
1200-
assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes)))
12011200
})
12021201
t.Run("read_max_large", func(t *testing.T) {
12031202
t.Parallel()
@@ -1206,16 +1205,14 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
12061205
}
12071206
// Serializes to much larger than readMaxBytes (5 MiB)
12081207
pingRequest := &pingv1.PingRequest{Text: strings.Repeat("abcde", 1024*1024)}
1209-
expectedSize := proto.Size(pingRequest)
12101208
// With gzip request compression, the error should indicate the envelope size (before decompression) is too large.
12111209
if compressed {
1212-
expectedSize = gzipCompressedSize(t, pingRequest)
1210+
expectedSize := gzipCompressedSize(t, pingRequest)
12131211
assert.True(t, expectedSize > readMaxBytes, assert.Sprintf("expected compressed size %d > %d", expectedSize, readMaxBytes))
12141212
}
12151213
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
12161214
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
12171215
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
1218-
assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes))
12191216
})
12201217
}
12211218
newHTTP2Server := func(t *testing.T) *memhttp.Server {
@@ -1378,7 +1375,6 @@ func TestClientWithReadMaxBytes(t *testing.T) {
13781375
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
13791376
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
13801377
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
1381-
assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes)))
13821378
})
13831379
t.Run("read_max_large", func(t *testing.T) {
13841380
t.Parallel()
@@ -1397,7 +1393,6 @@ func TestClientWithReadMaxBytes(t *testing.T) {
13971393
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
13981394
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
13991395
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
1400-
assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes))
14011396
})
14021397
}
14031398
t.Run("connect", func(t *testing.T) {

envelope.go

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,13 @@ type envelopeReader struct {
228228
compressionPool *compressionPool
229229
bufferPool *bufferPool
230230
readMaxBytes int
231+
isEOF bool
231232
}
232233

233234
func (r *envelopeReader) Unmarshal(message any) *Error {
235+
if r.isEOF {
236+
return NewError(CodeInternal, io.EOF)
237+
}
234238
buffer := r.bufferPool.Get()
235239
var dontRelease *bytes.Buffer
236240
defer func() {
@@ -240,25 +244,20 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
240244
}()
241245

242246
env := &envelope{Data: buffer}
243-
err := r.Read(env)
244-
switch {
245-
case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil:
247+
if err := r.Read(env); err != nil {
248+
// Mark the reader as EOF so that subsequent reads return EOF.
249+
r.isEOF = true
250+
return err
251+
}
252+
if env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil {
246253
return errorf(
247254
CodeInternal,
248255
"protocol error: sent compressed message without compression support",
249256
)
250-
case err == nil &&
251-
(env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
252-
env.Data.Len() == 0:
257+
} else if (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && env.Data.Len() == 0 {
253258
// This is a standard message (because none of the top 7 bits are set) and
254259
// there's no data, so the zero value of the message is correct.
255260
return nil
256-
case err != nil && errors.Is(err, io.EOF):
257-
// The stream has ended. Propagate the EOF to the caller.
258-
return err
259-
case err != nil:
260-
// Something's wrong.
261-
return err
262261
}
263262

264263
data := env.Data
@@ -317,7 +316,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
317316
// The stream ended cleanly. That's expected, but we need to propagate an EOF
318317
// to the user so that they know that the stream has ended. We shouldn't
319318
// add any alarming text about protocol errors, though.
320-
return NewError(CodeUnknown, err)
319+
return NewError(CodeInternal, err)
321320
}
322321
err = wrapIfMaxBytesError(err, "read 5 byte message prefix")
323322
err = wrapIfContextDone(r.ctx, err)
@@ -332,12 +331,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
332331
}
333332
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
334333
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
335-
n, err := io.CopyN(io.Discard, r.reader, size)
336-
r.bytesRead += n
337-
if err != nil && !errors.Is(err, io.EOF) {
338-
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err)
339-
}
340-
return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, r.readMaxBytes)
334+
// Resource is exhausted, fail fast without reading more data from the stream.
335+
return errorf(CodeResourceExhausted, "received message size %d is larger than configured max %d", size, r.readMaxBytes)
341336
}
342337
// We've read the prefix, so we know how many bytes to expect.
343338
// CopyN will return an error if it doesn't read the requested

protocol.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,12 @@ func isCommaOrSpace(c rune) bool {
287287
}
288288

289289
func discard(reader io.Reader) (int64, error) {
290-
if lr, ok := reader.(*io.LimitedReader); ok {
291-
return io.Copy(io.Discard, lr)
292-
}
293290
// We don't want to get stuck throwing data away forever, so limit how much
294291
// we're willing to do here.
295-
lr := &io.LimitedReader{R: reader, N: discardLimit}
292+
lr, ok := reader.(*io.LimitedReader)
293+
if !ok {
294+
lr = &io.LimitedReader{R: reader, N: discardLimit}
295+
}
296296
return io.Copy(io.Discard, lr)
297297
}
298298

protocol_connect.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,19 +1088,19 @@ type connectUnaryUnmarshaler struct {
10881088
codec Codec
10891089
compressionPool *compressionPool
10901090
bufferPool *bufferPool
1091-
alreadyRead bool
10921091
readMaxBytes int
1092+
isEOF bool
10931093
}
10941094

10951095
func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error {
10961096
return u.UnmarshalFunc(message, u.codec.Unmarshal)
10971097
}
10981098

10991099
func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error {
1100-
if u.alreadyRead {
1100+
if u.isEOF {
11011101
return NewError(CodeInternal, io.EOF)
11021102
}
1103-
u.alreadyRead = true
1103+
u.isEOF = true
11041104
data := u.bufferPool.Get()
11051105
defer u.bufferPool.Put(data)
11061106
reader := u.reader
@@ -1118,12 +1118,8 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by
11181118
return errorf(CodeUnknown, "read message: %w", err)
11191119
}
11201120
if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {
1121-
// Attempt to read to end in order to allow connection re-use
1122-
discardedBytes, err := io.Copy(io.Discard, u.reader)
1123-
if err != nil {
1124-
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err)
1125-
}
1126-
return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, u.readMaxBytes)
1121+
// Resource is exhausted, fail fast without reading more data from the stream.
1122+
return errorf(CodeResourceExhausted, "message size is larger than configured max %d", u.readMaxBytes)
11271123
}
11281124
if data.Len() > 0 && u.compressionPool != nil {
11291125
decompressed := u.bufferPool.Get()

protocol_grpc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ func (g *grpcClient) NewConn(
319319
}
320320
} else {
321321
conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header {
322-
// To access HTTP trailers, we need to read the body to EOF.
323-
_, _ = discard(call)
322+
// Caller must guarantee the body is read to EOF to access
323+
// trailers.
324324
return call.ResponseTrailer()
325325
}
326326
}

0 commit comments

Comments
 (0)