-
Notifications
You must be signed in to change notification settings - Fork 120
Add RecoverInterceptor as alternative to WithRecover #824
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,48 +17,292 @@ package connect | |||||||||||
import ( | ||||||||||||
"context" | ||||||||||||
"net/http" | ||||||||||||
"sync/atomic" | ||||||||||||
) | ||||||||||||
|
||||||||||||
// recoverHandlerInterceptor lets handlers trap panics, perform side effects | ||||||||||||
// (like emitting logs or metrics), and present a friendlier error message to | ||||||||||||
// clients. | ||||||||||||
type recoverHandlerInterceptor struct { | ||||||||||||
Interceptor | ||||||||||||
// RecoverInterceptor is an interceptor that recovers from panics. The | ||||||||||||
// supplied function receives the context and request details. | ||||||||||||
// | ||||||||||||
// For streaming RPCs, req.Any() may return nil. It will always be nil | ||||||||||||
// for client-streaming or bidi-streaming RPCs, since there could be | ||||||||||||
// zero or even multiple request messages for such RPCs. For | ||||||||||||
// server-streaming RPCs, it will be nil if the panic occurred before | ||||||||||||
// the request message was received, which can happen if a panic occurs | ||||||||||||
// in an interceptor before the RPC handler method is invoked. | ||||||||||||
// | ||||||||||||
// Similarly, for streaming RPCs, req.Header() may return nil. This | ||||||||||||
// could happen in clients when the panic that is recovered occurs | ||||||||||||
// before the stream is actually created and before request headers are | ||||||||||||
// even allocated. | ||||||||||||
// | ||||||||||||
// Applications will generally want to add this interceptor first, which | ||||||||||||
// means it will actually be the last to handle any results from the | ||||||||||||
// RPC handler. This allows for recovering from the panics not only in | ||||||||||||
// the handler but also in any other interceptors. | ||||||||||||
// | ||||||||||||
// The recovered value will never be nil. If panic was called with a nil | ||||||||||||
// value, the recovered value will be a *[runtime.PanicNilError]. It must | ||||||||||||
// return an error to send back to the client. If it returns nil, an | ||||||||||||
// *Error with a code of CodeInternal will ne synthesized. The function | ||||||||||||
// may also log the panic, emit metrics, or execute other error-handling | ||||||||||||
// logic. The function must be safe to call concurrently. | ||||||||||||
// | ||||||||||||
// By default, handlers don't recover from panics. Because the standard | ||||||||||||
// library's [http.Server] recovers from panics by default, this option | ||||||||||||
// isn't usually necessary to prevent crashes. Instead, it helps servers | ||||||||||||
// collect RPC-specific data during panics and send a more detailed error | ||||||||||||
// to clients. | ||||||||||||
// | ||||||||||||
// Unlike [WithRecover], this interceptor does not do anything special with | ||||||||||||
// [http.ErrAbortHandler], so the handle function may be called with that as | ||||||||||||
// the panic value. | ||||||||||||
// | ||||||||||||
// Also unlike [WithRecover], which can only be used with handlers, this | ||||||||||||
// interceptor can be used with clients, to recover from any panics caused | ||||||||||||
// by bugs in the interceptor chain. For streaming RPCs, this will recover | ||||||||||||
// from panics that happen in calls to send or receive messages on the | ||||||||||||
// stream or to close the stream. | ||||||||||||
func RecoverInterceptor(handle func(ctx context.Context, req AnyRequest, panicValue any) error) Interceptor { | ||||||||||||
return &recoverHandlerInterceptor{handle: handle} | ||||||||||||
} | ||||||||||||
|
||||||||||||
handle func(context.Context, Spec, http.Header, any) error | ||||||||||||
type recoverHandlerInterceptor struct { | ||||||||||||
handle func(context.Context, AnyRequest, any) error | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (i *recoverHandlerInterceptor) WrapUnary(next UnaryFunc) UnaryFunc { | ||||||||||||
return func(ctx context.Context, req AnyRequest) (_ AnyResponse, retErr error) { | ||||||||||||
if req.Spec().IsClient { | ||||||||||||
return next(ctx, req) | ||||||||||||
} | ||||||||||||
defer func() { | ||||||||||||
if r := recover(); r != nil { | ||||||||||||
// net/http checks for ErrAbortHandler with ==, so we should too. | ||||||||||||
if r == http.ErrAbortHandler { //nolint:errorlint,goerr113 | ||||||||||||
panic(r) //nolint:forbidigo | ||||||||||||
retErr = i.handle(ctx, req, r) | ||||||||||||
if retErr == nil { | ||||||||||||
retErr = errorf(CodeInternal, "handler panicked; but recover handler returned non-nil error") | ||||||||||||
} | ||||||||||||
retErr = i.handle(ctx, req.Spec(), req.Header(), r) | ||||||||||||
} | ||||||||||||
}() | ||||||||||||
res, err := next(ctx, req) | ||||||||||||
return res, err | ||||||||||||
return next(ctx, req) | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (i *recoverHandlerInterceptor) WrapStreamingHandler(next StreamingHandlerFunc) StreamingHandlerFunc { | ||||||||||||
return func(ctx context.Context, conn StreamingHandlerConn) (retErr error) { | ||||||||||||
var streamConn *recoverStreamingHandlerConn | ||||||||||||
if conn.Spec().StreamType == StreamTypeServer { | ||||||||||||
// There will be exactly one request. So we try to capture it | ||||||||||||
// so we can provide it to the recover handle func. | ||||||||||||
streamConn = &recoverStreamingHandlerConn{StreamingHandlerConn: conn} | ||||||||||||
conn = streamConn | ||||||||||||
} | ||||||||||||
|
||||||||||||
defer func() { | ||||||||||||
if r := recover(); r != nil { | ||||||||||||
// net/http checks for ErrAbortHandler with ==, so we should too. | ||||||||||||
if r == http.ErrAbortHandler { //nolint:errorlint,goerr113 | ||||||||||||
panic(r) //nolint:forbidigo | ||||||||||||
if panicVal := recover(); panicVal != nil { | ||||||||||||
var msg any | ||||||||||||
if streamConn != nil { | ||||||||||||
if msgPtr := streamConn.req.Load(); msgPtr != nil { | ||||||||||||
msg = *msgPtr | ||||||||||||
} | ||||||||||||
} | ||||||||||||
retErr = i.handle(ctx, &recoverStreamRequest{conn, msg}, panicVal) | ||||||||||||
if retErr == nil { | ||||||||||||
retErr = errorf(CodeInternal, "handler panicked; but recover handler returned non-nil error") | ||||||||||||
} | ||||||||||||
} | ||||||||||||
}() | ||||||||||||
return next(ctx, conn) | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (i *recoverHandlerInterceptor) WrapStreamingClient(next StreamingClientFunc) StreamingClientFunc { | ||||||||||||
return func(ctx context.Context, spec Spec) (conn StreamingClientConn) { | ||||||||||||
defer func() { | ||||||||||||
if panicVal := recover(); panicVal != nil { | ||||||||||||
err := i.handle(ctx, emptyRequest(spec), panicVal) | ||||||||||||
if err == nil { | ||||||||||||
err = errorf(CodeInternal, "call panicked; but recover handler returned non-nil error") | ||||||||||||
} | ||||||||||||
retErr = i.handle(ctx, conn.Spec(), conn.RequestHeader(), r) | ||||||||||||
conn = &errStreamingClientConn{spec, err} | ||||||||||||
} | ||||||||||||
}() | ||||||||||||
err := next(ctx, conn) | ||||||||||||
return err | ||||||||||||
conn = next(ctx, spec) | ||||||||||||
return &recoverStreamingClientConn{ | ||||||||||||
StreamingClientConn: conn, | ||||||||||||
ctx: ctx, | ||||||||||||
handle: i.handle, | ||||||||||||
} | ||||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
type recoverStreamRequest struct { | ||||||||||||
StreamingHandlerConn | ||||||||||||
msg any | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamRequest) Any() any { | ||||||||||||
return r.msg | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamRequest) Header() http.Header { | ||||||||||||
return r.RequestHeader() | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamRequest) HTTPMethod() string { | ||||||||||||
return http.MethodPost // streams always use POST | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamRequest) internalOnly() { | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamRequest) setRequestMethod(_ string) { | ||||||||||||
// only invoked internally for unary RPCs; safe to ignore | ||||||||||||
} | ||||||||||||
|
||||||||||||
type recoverStreamingHandlerConn struct { | ||||||||||||
StreamingHandlerConn | ||||||||||||
req atomic.Pointer[any] | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamingHandlerConn) Receive(msg any) error { | ||||||||||||
err := r.StreamingHandlerConn.Receive(msg) | ||||||||||||
if err == nil { | ||||||||||||
// Note: The framework instantiates msg, passes it to | ||||||||||||
// this method, and then returns it to the application. | ||||||||||||
// It is possible that the application could mutate the | ||||||||||||
// value, so what we provide to the recover handler would | ||||||||||||
// then differ from the message actually received. But | ||||||||||||
// this is no different than if the RPC handler mutated | ||||||||||||
// the request message for a unary RPC and interceptors | ||||||||||||
// later examined it via Request.Any. So we tolerate the | ||||||||||||
// possibility for server-stream requests, too. | ||||||||||||
r.req.Store(&msg) | ||||||||||||
} | ||||||||||||
return err | ||||||||||||
} | ||||||||||||
|
||||||||||||
type emptyRequest Spec | ||||||||||||
|
||||||||||||
func (e emptyRequest) Any() any { | ||||||||||||
return nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e emptyRequest) Spec() Spec { | ||||||||||||
return Spec(e) | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e emptyRequest) Peer() Peer { | ||||||||||||
return Peer{} | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e emptyRequest) Header() http.Header { | ||||||||||||
return nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e emptyRequest) HTTPMethod() string { | ||||||||||||
return http.MethodPost | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e emptyRequest) internalOnly() { | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e emptyRequest) setRequestMethod(_ string) { | ||||||||||||
// only invoked internally for unary RPCs; safe to ignore | ||||||||||||
} | ||||||||||||
|
||||||||||||
type errStreamingClientConn struct { | ||||||||||||
spec Spec | ||||||||||||
err error | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) Spec() Spec { | ||||||||||||
return e.spec | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) Peer() Peer { | ||||||||||||
return Peer{} | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) Send(_ any) error { | ||||||||||||
return e.err | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) RequestHeader() http.Header { | ||||||||||||
// Clients can add headers before calling Send, so this must be mutable/non-nil. | ||||||||||||
return http.Header{} // TODO: memoize so we never allocate more than one? | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) CloseRequest() error { | ||||||||||||
return e.err | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) Receive(_ any) error { | ||||||||||||
return e.err | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) ResponseHeader() http.Header { | ||||||||||||
return nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) ResponseTrailer() http.Header { | ||||||||||||
return nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (e *errStreamingClientConn) CloseResponse() error { | ||||||||||||
return e.err | ||||||||||||
} | ||||||||||||
|
||||||||||||
type recoverStreamingClientConn struct { | ||||||||||||
StreamingClientConn | ||||||||||||
|
||||||||||||
//nolint:containedctx // must memoize the stream context to pass to recover handler | ||||||||||||
ctx context.Context | ||||||||||||
handle func(context.Context, AnyRequest, any) error | ||||||||||||
req atomic.Pointer[any] | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's fair. I'm generally allergic to But in this case, there is a generic type parameter for the request type requiring it to always be a particular concrete type, so there should be no such risk. So it should be safe to just use There is a slight risk that a caller could call |
||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamingClientConn) Send(msg any) error { | ||||||||||||
if r.Spec().StreamType == StreamTypeServer { | ||||||||||||
// Capture the request message for server-streaming RPCs. | ||||||||||||
r.req.Store(&msg) | ||||||||||||
} | ||||||||||||
return r.invoke(func() error { | ||||||||||||
return r.StreamingClientConn.Send(msg) | ||||||||||||
}) | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamingClientConn) RequestHeader() http.Header { | ||||||||||||
if header := r.StreamingClientConn.RequestHeader(); header != nil { | ||||||||||||
return header | ||||||||||||
} | ||||||||||||
// Clients can add headers before calling Send, so this must be mutable/non-nil. | ||||||||||||
// We do this not to recover from a panic but in the hopes of preventing panics in the caller. | ||||||||||||
return http.Header{} // TODO: memoize so we never allocate more than one? | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamingClientConn) CloseRequest() error { | ||||||||||||
return r.invoke(r.StreamingClientConn.CloseRequest) | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamingClientConn) Receive(msg any) error { | ||||||||||||
return r.invoke(func() error { | ||||||||||||
return r.StreamingClientConn.Receive(msg) | ||||||||||||
}) | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamingClientConn) CloseResponse() error { | ||||||||||||
return r.invoke(r.StreamingClientConn.CloseResponse) | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (r *recoverStreamingClientConn) invoke(action func() error) (retErr error) { | ||||||||||||
defer func() { | ||||||||||||
if panicVal := recover(); panicVal != nil { | ||||||||||||
var msg any | ||||||||||||
if msgPtr := r.req.Load(); msgPtr != nil { | ||||||||||||
msg = *msgPtr | ||||||||||||
} | ||||||||||||
Comment on lines
+297
to
+300
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
retErr = r.handle(r.ctx, &recoverStreamRequest{r, msg}, panicVal) | ||||||||||||
if retErr == nil { | ||||||||||||
retErr = errorf(CodeInternal, "call panicked; but recover handler returned non-nil error") | ||||||||||||
} | ||||||||||||
} | ||||||||||||
}() | ||||||||||||
return action() | ||||||||||||
} |
Uh oh!
There was an error while loading. Please reload this page.