diff --git a/src/iface/interface/tests/ipv4.rs b/src/iface/interface/tests/ipv4.rs index 37685a7ba..154419e8b 100644 --- a/src/iface/interface/tests/ipv4.rs +++ b/src/iface/interface/tests/ipv4.rs @@ -1029,6 +1029,76 @@ fn test_raw_socket_with_udp_socket(#[case] medium: Medium) { ); } +#[rstest] +#[case(Medium::Ip)] +#[cfg(all( + feature = "socket-raw", + feature = "proto-ipv4-fragmentation", + feature = "medium-ip" +))] +#[case(Medium::Ethernet)] +#[cfg(all( + feature = "socket-raw", + feature = "proto-ipv4-fragmentation", + feature = "medium-ethernet" +))] +fn test_raw_socket_tx_fragmentation(#[case] medium: Medium) { + use std::panic::AssertUnwindSafe; + + let (mut iface, mut sockets, device) = setup(medium); + let mtu = device.capabilities().max_transmission_unit; + + let packets = 5; + let rx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; packets], + vec![0; mtu * packets], + ); + let tx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; packets], + vec![0; mtu * packets], + ); + let socket = raw::Socket::new( + Some(IpVersion::Ipv4), + Some(IpProtocol::Udp), + rx_buffer, + tx_buffer, + ); + let _handle = sockets.add(socket); + + let tx_packet_sizes = vec![ + mtu * 3 / 4, // Smaller than MTU + mtu * 5 / 4, // Larger than MTU, requires fragmentation + mtu * 9 / 4, // Much larger, requires two fragments + ]; + for packet_size in tx_packet_sizes { + let payload_len = packet_size - IPV4_HEADER_LEN; + let payload = vec![0u8; payload_len]; + + let ip_repr = Ipv4Repr { + src_addr: Ipv4Address::new(192, 168, 1, 3), + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Unknown(92), + hop_limit: 64, + payload_len, + }; + let ip_payload = IpPayload::Raw(&payload); + let packet = Packet::new_ipv4(ip_repr, ip_payload); + + // This should not panic for any payload size + let result = std::panic::catch_unwind(AssertUnwindSafe(|| { + iface.inner.dispatch_ip( + MockTxToken {}, + PacketMeta::default(), + packet, + &mut iface.fragmenter, + ) + })); + + // All transmissions should succeed without panicking + assert!(result.is_ok(), "Failed for packet size: {}", packet_size,); + } +} + #[rstest] #[case(Medium::Ip)] #[cfg(all(feature = "socket-udp", feature = "medium-ip"))] diff --git a/src/iface/packet.rs b/src/iface/packet.rs index 70e5a5ccf..287687ea9 100644 --- a/src/iface/packet.rs +++ b/src/iface/packet.rs @@ -130,7 +130,10 @@ impl<'p> Packet<'p> { } #[cfg(feature = "socket-raw")] - IpPayload::Raw(raw_packet) => payload.copy_from_slice(raw_packet), + IpPayload::Raw(raw_packet) => { + let len = raw_packet.len(); + payload[..len].copy_from_slice(raw_packet) + } #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] IpPayload::Udp(udp_repr, inner_payload) => udp_repr.emit( &mut UdpPacket::new_unchecked(payload),