From a6cd31e09078cbf7dba2d7c0c4a458ef7049d573 Mon Sep 17 00:00:00 2001 From: KingCol13 <48412633+KingCol13@users.noreply.github.com> Date: Tue, 24 Jun 2025 19:37:42 +0100 Subject: [PATCH 1/2] Allow raw socket to receive all protocols and versions --- examples/multicast.rs | 4 +- src/iface/interface/tests/ipv4.rs | 11 ++++-- src/socket/raw.rs | 61 +++++++++++++++++-------------- 3 files changed, 44 insertions(+), 32 deletions(-) diff --git a/examples/multicast.rs b/examples/multicast.rs index ab86a4bce..024c034df 100644 --- a/examples/multicast.rs +++ b/examples/multicast.rs @@ -66,8 +66,8 @@ fn main() { // Will not send IGMP let raw_tx_buffer = raw::PacketBuffer::new(vec![], vec![]); let raw_socket = raw::Socket::new( - IpVersion::Ipv4, - IpProtocol::Igmp, + Some(IpVersion::Ipv4), + Some(IpProtocol::Igmp), raw_rx_buffer, raw_tx_buffer, ); diff --git a/src/iface/interface/tests/ipv4.rs b/src/iface/interface/tests/ipv4.rs index a29bd20f8..37685a7ba 100644 --- a/src/iface/interface/tests/ipv4.rs +++ b/src/iface/interface/tests/ipv4.rs @@ -852,7 +852,12 @@ fn test_raw_socket_no_reply(#[case] medium: Medium) { vec![raw::PacketMetadata::EMPTY; packets], vec![0; 48 * packets], ); - let raw_socket = raw::Socket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer); + let raw_socket = raw::Socket::new( + Some(IpVersion::Ipv4), + Some(IpProtocol::Udp), + rx_buffer, + tx_buffer, + ); sockets.add(raw_socket); let src_addr = Ipv4Address::new(127, 0, 0, 2); @@ -948,8 +953,8 @@ fn test_raw_socket_with_udp_socket(#[case] medium: Medium) { vec![0; 48 * packets], ); let raw_socket = raw::Socket::new( - IpVersion::Ipv4, - IpProtocol::Udp, + Some(IpVersion::Ipv4), + Some(IpProtocol::Udp), raw_rx_buffer, raw_tx_buffer, ); diff --git a/src/socket/raw.rs b/src/socket/raw.rs index c79f99b95..13fd07582 100644 --- a/src/socket/raw.rs +++ b/src/socket/raw.rs @@ -80,12 +80,12 @@ pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>; /// A raw IP socket. /// -/// A raw socket is bound to a specific IP protocol, and owns +/// A raw socket may be bound to a specific IP protocol, and owns /// transmit and receive packet buffers. #[derive(Debug)] pub struct Socket<'a> { - ip_version: IpVersion, - ip_protocol: IpProtocol, + ip_version: Option, + ip_protocol: Option, rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>, #[cfg(feature = "async")] @@ -98,8 +98,8 @@ impl<'a> Socket<'a> { /// Create a raw IP socket bound to the given IP version and datagram protocol, /// with the given buffers. pub fn new( - ip_version: IpVersion, - ip_protocol: IpProtocol, + ip_version: Option, + ip_protocol: Option, rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>, ) -> Socket<'a> { @@ -152,13 +152,13 @@ impl<'a> Socket<'a> { /// Return the IP version the socket is bound to. #[inline] - pub fn ip_version(&self) -> IpVersion { + pub fn ip_version(&self) -> Option { self.ip_version } /// Return the IP protocol the socket is bound to. #[inline] - pub fn ip_protocol(&self) -> IpProtocol { + pub fn ip_protocol(&self) -> Option { self.ip_protocol } @@ -216,7 +216,7 @@ impl<'a> Socket<'a> { .map_err(|_| SendError::BufferFull)?; net_trace!( - "raw:{}:{}: buffer to send {} octets", + "raw:{:?}:{:?}: buffer to send {} octets", self.ip_version, self.ip_protocol, packet_buf.len() @@ -238,7 +238,7 @@ impl<'a> Socket<'a> { .map_err(|_| SendError::BufferFull)?; net_trace!( - "raw:{}:{}: buffer to send {} octets", + "raw:{:?}:{:?}: buffer to send {} octets", self.ip_version, self.ip_protocol, size @@ -265,7 +265,7 @@ impl<'a> Socket<'a> { let ((), packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; net_trace!( - "raw:{}:{}: receive {} buffered octets", + "raw:{:?}:{:?}: receive {} buffered octets", self.ip_version, self.ip_protocol, packet_buf.len() @@ -299,7 +299,7 @@ impl<'a> Socket<'a> { let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?; net_trace!( - "raw:{}:{}: receive {} buffered octets", + "raw:{:?}:{:?}: receive {} buffered octets", self.ip_version, self.ip_protocol, packet_buf.len() @@ -338,10 +338,17 @@ impl<'a> Socket<'a> { } pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool { - if ip_repr.version() != self.ip_version { + if self + .ip_version + .is_some_and(|version| version != ip_repr.version()) + { return false; } - if ip_repr.next_header() != self.ip_protocol { + + if self + .ip_protocol + .is_some_and(|next_header| next_header != ip_repr.next_header()) + { return false; } @@ -355,7 +362,7 @@ impl<'a> Socket<'a> { let total_len = header_len + payload.len(); net_trace!( - "raw:{}:{}: receiving {} octets", + "raw:{:?}:{:?}: receiving {} octets", self.ip_version, self.ip_protocol, total_len @@ -367,7 +374,7 @@ impl<'a> Socket<'a> { buf[header_len..].copy_from_slice(payload); } Err(_) => net_trace!( - "raw:{}:{}: buffer full, dropped incoming packet", + "raw:{:?}:{:?}: buffer full, dropped incoming packet", self.ip_version, self.ip_protocol ), @@ -395,7 +402,7 @@ impl<'a> Socket<'a> { return Ok(()); } }; - if packet.next_header() != ip_protocol { + if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) { net_trace!("raw: sent packet with wrong ip protocol, dropping."); return Ok(()); } @@ -415,7 +422,7 @@ impl<'a> Socket<'a> { return Ok(()); } }; - net_trace!("raw:{}:{}: sending", ip_version, ip_protocol); + net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol); emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload())) } #[cfg(feature = "proto-ipv6")] @@ -427,7 +434,7 @@ impl<'a> Socket<'a> { return Ok(()); } }; - if packet.next_header() != ip_protocol { + if ip_protocol.is_some_and(|next_header| next_header != packet.next_header()) { net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping."); return Ok(()); } @@ -440,7 +447,7 @@ impl<'a> Socket<'a> { } }; - net_trace!("raw:{}:{}: sending", ip_version, ip_protocol); + net_trace!("raw:{:?}:{:?}: sending", ip_version, ip_protocol); emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload())) } Err(_) => { @@ -495,8 +502,8 @@ mod test { tx_buffer: PacketBuffer<'static>, ) -> Socket<'static> { Socket::new( - IpVersion::Ipv4, - IpProtocol::Unknown(IP_PROTO), + Some(IpVersion::Ipv4), + Some(IpProtocol::Unknown(IP_PROTO)), rx_buffer, tx_buffer, ) @@ -526,8 +533,8 @@ mod test { tx_buffer: PacketBuffer<'static>, ) -> Socket<'static> { Socket::new( - IpVersion::Ipv6, - IpProtocol::Unknown(IP_PROTO), + Some(IpVersion::Ipv6), + Some(IpProtocol::Unknown(IP_PROTO)), rx_buffer, tx_buffer, ) @@ -827,8 +834,8 @@ mod test { #[cfg(feature = "proto-ipv4")] { let socket = Socket::new( - IpVersion::Ipv4, - IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1), + Some(IpVersion::Ipv4), + Some(IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1)), buffer(1), buffer(1), ); @@ -839,8 +846,8 @@ mod test { #[cfg(feature = "proto-ipv6")] { let socket = Socket::new( - IpVersion::Ipv6, - IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1), + Some(IpVersion::Ipv6), + Some(IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1)), buffer(1), buffer(1), ); From 97790597d70968b2b23151d3fc2c832b7b696429 Mon Sep 17 00:00:00 2001 From: KingCol13 <48412633+KingCol13@users.noreply.github.com> Date: Tue, 24 Jun 2025 21:04:14 +0100 Subject: [PATCH 2/2] Add tests for unfiltered raw socket --- src/socket/raw.rs | 90 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/src/socket/raw.rs b/src/socket/raw.rs index 13fd07582..a366b1b28 100644 --- a/src/socket/raw.rs +++ b/src/socket/raw.rs @@ -856,4 +856,94 @@ mod test { assert!(!socket.accepts(&ipv4_locals::HEADER_REPR)); } } + + fn check_dispatch(socket: &mut Socket<'_>, cx: &mut Context) { + // Check dispatch returns Ok(()) and calls the emit closure + let mut emitted = false; + assert_eq!( + socket.dispatch(cx, |_, _| { + emitted = true; + Ok(()) + }), + Ok::<_, ()>(()) + ); + assert!(emitted); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_unfiltered_sends_all(#[case] medium: Medium) { + // Test a single unfiltered socket can send packets with different IP versions and next + // headers + let mut socket = Socket::new(None, None, buffer(0), buffer(2)); + #[cfg(feature = "proto-ipv4")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut udp_packet = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut udp_packet).set_next_header(IpProtocol::Udp); + + assert_eq!(socket.send_slice(&udp_packet), Ok(())); + check_dispatch(&mut socket, cx); + + let mut tcp_packet = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut tcp_packet).set_next_header(IpProtocol::Tcp); + + assert_eq!(socket.send_slice(&tcp_packet[..]), Ok(())); + check_dispatch(&mut socket, cx); + } + #[cfg(feature = "proto-ipv6")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut udp_packet = ipv6_locals::PACKET_BYTES; + Ipv6Packet::new_unchecked(&mut udp_packet).set_next_header(IpProtocol::Udp); + + assert_eq!(socket.send_slice(&ipv6_locals::PACKET_BYTES), Ok(())); + check_dispatch(&mut socket, cx); + + let mut tcp_packet = ipv6_locals::PACKET_BYTES; + Ipv6Packet::new_unchecked(&mut tcp_packet).set_next_header(IpProtocol::Tcp); + + assert_eq!(socket.send_slice(&tcp_packet[..]), Ok(())); + check_dispatch(&mut socket, cx); + } + } + + #[rstest] + #[case::proto(IpProtocol::Icmp)] + #[case::proto(IpProtocol::Tcp)] + #[case::proto(IpProtocol::Udp)] + fn test_unfiltered_accepts_all(#[case] proto: IpProtocol) { + // Test an unfiltered socket can accept packets with different IP versions and next headers + let socket = Socket::new(None, None, buffer(0), buffer(0)); + #[cfg(feature = "proto-ipv4")] + { + let header_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address::new(10, 0, 0, 1), + dst_addr: Ipv4Address::new(10, 0, 0, 2), + next_header: proto, + payload_len: 4, + hop_limit: 64, + }); + assert!(socket.accepts(&header_repr)); + } + #[cfg(feature = "proto-ipv6")] + { + let header_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + dst_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2), + next_header: proto, + payload_len: 4, + hop_limit: 64, + }); + assert!(socket.accepts(&header_repr)); + } + } }