Skip to content

Commit ddc2000

Browse files
committed
Split udp/write operation into a 3-step beginPacket/write/endPacket
1 parent 3a4b461 commit ddc2000

File tree

2 files changed

+90
-14
lines changed

2 files changed

+90
-14
lines changed

network-api/network-api.go

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ func Register(router *msgpackrouter.Router) {
4747
_ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL)
4848

4949
_ = router.RegisterMethod("udp/connect", udpConnect)
50+
_ = router.RegisterMethod("udp/beginPacket", udpBeginPacket)
5051
_ = router.RegisterMethod("udp/write", udpWrite)
52+
_ = router.RegisterMethod("udp/endPacket", udpEndPacket)
5153
_ = router.RegisterMethod("udp/awaitRead", udpAwaitRead)
5254
_ = router.RegisterMethod("udp/read", udpRead)
5355
_ = router.RegisterMethod("udp/close", udpClose)
@@ -58,6 +60,8 @@ var liveConnections = make(map[uint]net.Conn)
5860
var liveListeners = make(map[uint]net.Listener)
5961
var liveUdpConnections = make(map[uint]net.PacketConn)
6062
var udpReadBuffers = make(map[uint][]byte)
63+
var udpWriteTargets = make(map[uint]*net.UDPAddr)
64+
var udpWriteBuffers = make(map[uint][]byte)
6165
var nextConnectionID atomic.Uint32
6266

6367
// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
@@ -375,9 +379,9 @@ func udpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (
375379
return id, nil
376380
}
377381

378-
func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
379-
if len(params) != 4 {
380-
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port, payload"}
382+
func udpBeginPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
383+
if len(params) != 3 {
384+
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port"}
381385
}
382386
id, ok := msgpackrpc.ToUint(params[0])
383387
if !ok {
@@ -391,9 +395,33 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
391395
if !ok {
392396
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
393397
}
394-
data, ok := params[3].([]byte)
398+
399+
lock.RLock()
400+
defer lock.RUnlock()
401+
if _, ok := liveUdpConnections[id]; !ok {
402+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
403+
}
404+
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
405+
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
406+
if err != nil {
407+
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
408+
}
409+
udpWriteTargets[id] = addr
410+
udpWriteBuffers[id] = nil
411+
return true, nil
412+
}
413+
414+
func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
415+
if len(params) != 2 {
416+
return nil, []any{1, "Invalid number of parameters, expected expected udpConnId, payload"}
417+
}
418+
id, ok := msgpackrpc.ToUint(params[0])
419+
if !ok {
420+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
421+
}
422+
data, ok := params[1].([]byte)
395423
if !ok {
396-
if dataStr, ok := params[3].(string); ok {
424+
if dataStr, ok := params[1].(string); ok {
397425
data = []byte(dataStr)
398426
} else {
399427
// If data is not []byte or string, return an error
@@ -402,18 +430,45 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
402430
}
403431

404432
lock.RLock()
405-
udpConn, ok := liveUdpConnections[id]
433+
udpBuffer, ok := udpWriteBuffers[id]
434+
if ok {
435+
udpWriteBuffers[id] = append(udpBuffer, data...)
436+
}
406437
lock.RUnlock()
407438
if !ok {
408439
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
409440
}
441+
return len(data), nil
442+
}
410443

411-
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
412-
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
413-
if err != nil {
414-
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
444+
func udpEndPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
445+
if len(params) != 1 {
446+
return nil, []any{1, "Invalid number of parameters, expected expected udpConnId"}
415447
}
416-
if n, err := udpConn.WriteTo(data, addr); err != nil {
448+
id, buffExists := msgpackrpc.ToUint(params[0])
449+
if !buffExists {
450+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
451+
}
452+
453+
var udpBuffer []byte
454+
var udpAddr *net.UDPAddr
455+
lock.RLock()
456+
udpConn, connExists := liveUdpConnections[id]
457+
if connExists {
458+
udpBuffer, buffExists = udpWriteBuffers[id]
459+
udpAddr = udpWriteTargets[id]
460+
delete(udpWriteBuffers, id)
461+
delete(udpWriteTargets, id)
462+
}
463+
lock.RUnlock()
464+
if !connExists {
465+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
466+
}
467+
if !buffExists {
468+
return nil, []any{3, fmt.Sprintf("No UDP packet begun for ID: %d", id)}
469+
}
470+
471+
if n, err := udpConn.WriteTo(udpBuffer, udpAddr); err != nil {
417472
return nil, []any{4, "Failed to write to UDP connection: " + err.Error()}
418473
} else {
419474
return n, nil

network-api/network-api_test.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,13 @@ func TestUDPNetworkAPI(t *testing.T) {
248248
require.NotEqual(t, conn1, conn2)
249249

250250
{
251-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Hello")})
251+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
252+
require.Nil(t, err)
253+
require.True(t, res.(bool))
254+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Hello")})
255+
require.Nil(t, err)
256+
require.Equal(t, 5, res)
257+
res, err = udpEndPacket(ctx, nil, []any{conn1})
252258
require.Nil(t, err)
253259
require.Equal(t, 5, res)
254260
}
@@ -262,12 +268,27 @@ func TestUDPNetworkAPI(t *testing.T) {
262268
require.Equal(t, []uint8("Hello"), res2)
263269
}
264270
{
265-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
271+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
272+
require.Nil(t, err)
273+
require.True(t, res.(bool))
274+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("On")})
275+
require.Nil(t, err)
276+
require.Equal(t, 2, res)
277+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("e")})
278+
require.Nil(t, err)
279+
require.Equal(t, 1, res)
280+
res, err = udpEndPacket(ctx, nil, []any{conn1})
266281
require.Nil(t, err)
267282
require.Equal(t, 3, res)
268283
}
269284
{
270-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Two")})
285+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
286+
require.Nil(t, err)
287+
require.True(t, res.(bool))
288+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Two")})
289+
require.Nil(t, err)
290+
require.Equal(t, 3, res)
291+
res, err = udpEndPacket(ctx, nil, []any{conn1})
271292
require.Nil(t, err)
272293
require.Equal(t, 3, res)
273294
}

0 commit comments

Comments
 (0)