Skip to content

Commit a5f270d

Browse files
authored
Fixes for consumer group (#1022)
* parse join group response meta based on version in metadata * decode sync version properly * check for assigment in syncgroup before unmarshaling * fix heartbeat reqs * test for heartbeat triggering generation end * Client.JoinGroup sarama compatibility test * Client.SyncGroup: teset to ensure v0 compatibility * Generation.OffsetCommit: return errors from api responses
1 parent cf40a01 commit a5f270d

File tree

8 files changed

+284
-9
lines changed

8 files changed

+284
-9
lines changed

consumergroup.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,15 @@ func (g *Generation) CommitOffsets(offsets map[string]map[int]int64) error {
432432
Topics: topics,
433433
}
434434

435-
_, err := g.coord.offsetCommit(genCtx{g}, request)
435+
resp, err := g.coord.offsetCommit(genCtx{g}, request)
436436
if err == nil {
437+
for _, partitions := range resp.Topics {
438+
for _, partition := range partitions {
439+
if partition.Error != nil {
440+
return partition.Error
441+
}
442+
}
443+
}
437444
// if logging is enabled, print out the partitions that were committed.
438445
g.log(func(l Logger) {
439446
var report []string
@@ -470,14 +477,17 @@ func (g *Generation) heartbeatLoop(interval time.Duration) {
470477
case <-ctx.Done():
471478
return
472479
case <-ticker.C:
473-
_, err := g.coord.heartbeat(ctx, &HeartbeatRequest{
480+
resp, err := g.coord.heartbeat(ctx, &HeartbeatRequest{
474481
GroupID: g.GroupID,
475482
GenerationID: g.ID,
476483
MemberID: g.MemberID,
477484
})
478485
if err != nil {
479486
return
480487
}
488+
if resp.Error != nil {
489+
return
490+
}
481491
}
482492
}
483493
})
@@ -1091,6 +1101,9 @@ func (cg *ConsumerGroup) fetchOffsets(subs map[string][]int) (map[string]map[int
10911101
for topic, offsets := range offsets.Topics {
10921102
offsetsByPartition := map[int]int64{}
10931103
for _, pr := range offsets {
1104+
if pr.Error != nil {
1105+
return nil, pr.Error
1106+
}
10941107
if pr.CommittedOffset < 0 {
10951108
pr.CommittedOffset = cg.config.StartOffset
10961109
}
@@ -1137,14 +1150,17 @@ func (cg *ConsumerGroup) leaveGroup(ctx context.Context, memberID string) error
11371150
log.Printf("Leaving group %s, member %s", cg.config.ID, memberID)
11381151
})
11391152

1140-
_, err := cg.coord.leaveGroup(ctx, &LeaveGroupRequest{
1153+
resp, err := cg.coord.leaveGroup(ctx, &LeaveGroupRequest{
11411154
GroupID: cg.config.ID,
11421155
Members: []LeaveGroupRequestMember{
11431156
{
11441157
ID: memberID,
11451158
},
11461159
},
11471160
})
1161+
if err == nil && resp.Error != nil {
1162+
err = resp.Error
1163+
}
11481164
if err != nil {
11491165
cg.withErrorLogger(func(log Logger) {
11501166
log.Printf("leave group failed for group, %v, and member, %v: %v", cg.config.ID, memberID, err)

consumergroup_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,96 @@ func TestGenerationStartsFunctionAfterClosed(t *testing.T) {
606606
}
607607
}
608608
}
609+
610+
func TestGenerationEndsOnHeartbeatError(t *testing.T) {
611+
gen := Generation{
612+
coord: &mockCoordinator{
613+
heartbeatFunc: func(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error) {
614+
return nil, errors.New("some error")
615+
},
616+
},
617+
done: make(chan struct{}),
618+
joined: make(chan struct{}),
619+
log: func(func(Logger)) {},
620+
logError: func(func(Logger)) {},
621+
}
622+
623+
ch := make(chan error)
624+
gen.Start(func(ctx context.Context) {
625+
<-ctx.Done()
626+
ch <- ctx.Err()
627+
})
628+
629+
gen.heartbeatLoop(time.Millisecond)
630+
631+
select {
632+
case <-time.After(time.Second):
633+
t.Fatal("timed out waiting for func to run")
634+
case err := <-ch:
635+
if !errors.Is(err, ErrGenerationEnded) {
636+
t.Fatalf("expected %v but got %v", ErrGenerationEnded, err)
637+
}
638+
}
639+
}
640+
641+
func TestGenerationEndsOnHeartbeatRebalaceInProgress(t *testing.T) {
642+
gen := Generation{
643+
coord: &mockCoordinator{
644+
heartbeatFunc: func(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error) {
645+
return &HeartbeatResponse{
646+
Error: makeError(int16(RebalanceInProgress), ""),
647+
}, nil
648+
},
649+
},
650+
done: make(chan struct{}),
651+
joined: make(chan struct{}),
652+
log: func(func(Logger)) {},
653+
logError: func(func(Logger)) {},
654+
}
655+
656+
ch := make(chan error)
657+
gen.Start(func(ctx context.Context) {
658+
<-ctx.Done()
659+
ch <- ctx.Err()
660+
})
661+
662+
gen.heartbeatLoop(time.Millisecond)
663+
664+
select {
665+
case <-time.After(time.Second):
666+
t.Fatal("timed out waiting for func to run")
667+
case err := <-ch:
668+
if !errors.Is(err, ErrGenerationEnded) {
669+
t.Fatalf("expected %v but got %v", ErrGenerationEnded, err)
670+
}
671+
}
672+
}
673+
674+
func TestGenerationOffsetCommitErrorsAreReturned(t *testing.T) {
675+
mc := mockCoordinator{
676+
offsetCommitFunc: func(context.Context, *OffsetCommitRequest) (*OffsetCommitResponse, error) {
677+
return &OffsetCommitResponse{
678+
Topics: map[string][]OffsetCommitPartition{
679+
"topic": {
680+
{
681+
Error: ErrGenerationEnded,
682+
},
683+
},
684+
},
685+
}, nil
686+
},
687+
}
688+
gen := Generation{
689+
coord: mc,
690+
log: func(func(Logger)) {},
691+
}
692+
693+
err := gen.CommitOffsets(map[string]map[int]int64{
694+
"topic": {
695+
0: 100,
696+
},
697+
})
698+
if err == nil {
699+
t.Fatal("got nil from CommitOffsets when expecting an error")
700+
}
701+
}

joingroup.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package kafka
33
import (
44
"bufio"
55
"context"
6+
"errors"
67
"fmt"
8+
"io"
79
"net"
810
"time"
911

@@ -163,7 +165,9 @@ func (c *Client) JoinGroup(ctx context.Context, req *JoinGroupRequest) (*JoinGro
163165

164166
for _, member := range r.Members {
165167
var meta consumer.Subscription
166-
err = protocol.Unmarshal(member.Metadata, consumer.MaxVersionSupported, &meta)
168+
metaVersion := makeInt16(member.Metadata[0:2])
169+
err = protocol.Unmarshal(member.Metadata, metaVersion, &meta)
170+
err = joinGroupSubscriptionMetaError(err, metaVersion)
167171
if err != nil {
168172
return nil, fmt.Errorf("kafka.(*Client).JoinGroup: %w", err)
169173
}
@@ -188,6 +192,16 @@ func (c *Client) JoinGroup(ctx context.Context, req *JoinGroupRequest) (*JoinGro
188192
return res, nil
189193
}
190194

195+
// sarama indicates there are some misbehaving clients out there that
196+
// set the version as 1 but don't include the OwnedPartitions section
197+
// https://github.com/Shopify/sarama/blob/610514edec1825240d59b62e4d7f1aba4b1fa000/consumer_group_members.go#L43
198+
func joinGroupSubscriptionMetaError(err error, version int16) error {
199+
if version >= 1 && errors.Is(err, io.ErrUnexpectedEOF) {
200+
return nil
201+
}
202+
return err
203+
}
204+
191205
type groupMetadata struct {
192206
Version int16
193207
Topics []string

joingroup_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@ import (
55
"bytes"
66
"context"
77
"errors"
8+
"net"
89
"reflect"
910
"testing"
1011
"time"
1112

13+
"github.com/segmentio/kafka-go/protocol"
14+
"github.com/segmentio/kafka-go/protocol/consumer"
15+
"github.com/segmentio/kafka-go/protocol/joingroup"
1216
ktesting "github.com/segmentio/kafka-go/testing"
1317
)
1418

@@ -124,6 +128,84 @@ func TestClientJoinGroup(t *testing.T) {
124128
}
125129
}
126130

131+
type roundTripFn func(context.Context, net.Addr, Request) (Response, error)
132+
133+
func (f roundTripFn) RoundTrip(ctx context.Context, addr net.Addr, req Request) (Response, error) {
134+
return f(ctx, addr, req)
135+
}
136+
137+
// https://github.com/Shopify/sarama/blob/610514edec1825240d59b62e4d7f1aba4b1fa000/consumer_group_members.go#L43
138+
func TestClientJoinGroupSaramaCompatibility(t *testing.T) {
139+
subscription := consumer.Subscription{
140+
Version: 1,
141+
Topics: []string{"topic"},
142+
}
143+
144+
// Marhsal as Verzon 0 (Without OwnedPartitions) but
145+
// with Version=1.
146+
metadata, err := protocol.Marshal(0, subscription)
147+
if err != nil {
148+
t.Fatalf("failed to marshal subscription %v", err)
149+
}
150+
151+
client := &Client{
152+
Addr: TCP("fake:9092"),
153+
Transport: roundTripFn(func(context.Context, net.Addr, Request) (Response, error) {
154+
resp := joingroup.Response{
155+
ProtocolType: "consumer",
156+
ProtocolName: RoundRobinGroupBalancer{}.ProtocolName(),
157+
LeaderID: "member",
158+
MemberID: "member",
159+
Members: []joingroup.ResponseMember{
160+
{
161+
MemberID: "member",
162+
Metadata: metadata,
163+
},
164+
},
165+
}
166+
return &resp, nil
167+
}),
168+
}
169+
170+
expResp := JoinGroupResponse{
171+
ProtocolName: RoundRobinGroupBalancer{}.ProtocolName(),
172+
ProtocolType: "consumer",
173+
LeaderID: "member",
174+
MemberID: "member",
175+
Members: []JoinGroupResponseMember{
176+
{
177+
ID: "member",
178+
Metadata: GroupProtocolSubscription{
179+
Topics: []string{"topic"},
180+
OwnedPartitions: map[string][]int{},
181+
},
182+
},
183+
},
184+
}
185+
186+
gotResp, err := client.JoinGroup(context.Background(), &JoinGroupRequest{
187+
GroupID: "group",
188+
MemberID: "member",
189+
ProtocolType: "consumer",
190+
Protocols: []GroupProtocol{
191+
{
192+
Name: RoundRobinGroupBalancer{}.ProtocolName(),
193+
Metadata: GroupProtocolSubscription{
194+
Topics: []string{"topic"},
195+
UserData: metadata,
196+
},
197+
},
198+
},
199+
})
200+
if err != nil {
201+
t.Fatalf("error calling JoinGroup: %v", err)
202+
}
203+
204+
if !reflect.DeepEqual(expResp, *gotResp) {
205+
t.Fatalf("unexpected JoinGroup resp\nexpected: %#v\n got: %#v", expResp, *gotResp)
206+
}
207+
}
208+
127209
func TestSaramaCompatibility(t *testing.T) {
128210
var (
129211
// sample data from github.com/Shopify/sarama

protocol/heartbeat/heartbeat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ type Response struct {
2727
// type.
2828
_ struct{} `kafka:"min=v4,max=v4,tag"`
2929

30-
ErrorCode int16 `kafka:"min=v0,max=v4"`
3130
ThrottleTimeMs int32 `kafka:"min=v1,max=v4"`
31+
ErrorCode int16 `kafka:"min=v0,max=v4"`
3232
}
3333

3434
func (r *Response) ApiKey() protocol.ApiKey {

reader_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1684,7 +1684,7 @@ func TestConsumerGroupMultipleWithDefaultTransport(t *testing.T) {
16841684
recvErr2 <- err
16851685
}()
16861686

1687-
time.Sleep(conf1.MaxWait)
1687+
time.Sleep(conf1.MaxWait * 5)
16881688

16891689
totalMessages := 10
16901690

syncgroup.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,13 @@ func (c *Client) SyncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGro
127127
r := m.(*syncgroup.Response)
128128

129129
var assignment consumer.Assignment
130-
err = protocol.Unmarshal(r.Assignments, consumer.MaxVersionSupported, &assignment)
131-
if err != nil {
132-
return nil, fmt.Errorf("kafka.(*Client).SyncGroup: %w", err)
130+
var metaVersion int16
131+
if len(r.Assignments) > 2 {
132+
metaVersion = makeInt16(r.Assignments[0:2])
133+
err = protocol.Unmarshal(r.Assignments, metaVersion, &assignment)
134+
if err != nil {
135+
return nil, fmt.Errorf("kafka.(*Client).SyncGroup: %w", err)
136+
}
133137
}
134138

135139
res := &SyncGroupResponse{

0 commit comments

Comments
 (0)