Skip to content

Return opaque iterator for nearest neighbor searches #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 63 additions & 10 deletions src/rtree/trait.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::marker::PhantomData;

use geo_traits::{CoordTrait, RectTrait};

Expand Down Expand Up @@ -151,7 +152,7 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {
y: N,
max_results: Option<usize>,
max_distance: Option<N>,
) -> Vec<u32> {
) -> impl Iterator<Item = u32> {
let boxes = self.boxes();
let indices = self.indices();
let max_distance = max_distance.unwrap_or(N::max_value());
Expand Down Expand Up @@ -179,13 +180,13 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {

if node_index >= self.num_items() as usize * 4 {
// node (use even id)
queue.push(Reverse(NeighborNode {
queue.push(Reverse(Neighbor {
id: index << 1,
dist,
}));
} else {
// leaf item (use odd id)
queue.push(Reverse(NeighborNode {
queue.push(Reverse(Neighbor {
id: (index << 1) + 1,
dist,
}));
Expand All @@ -212,7 +213,7 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {
}
}

results
results.into_iter()
}

/// Search items in order of distance from the given coordinate.
Expand All @@ -221,7 +222,7 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {
coord: &impl CoordTrait<T = N>,
max_results: Option<usize>,
max_distance: Option<N>,
) -> Vec<u32> {
) -> impl Iterator<Item = u32> {
self.neighbors(coord.x(), coord.y(), max_results, max_distance)
}

Expand All @@ -244,21 +245,28 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {

/// A wrapper around a node and its distance for use in the priority queue.
#[derive(Debug, Clone, Copy, PartialEq)]
struct NeighborNode<N: IndexableNum> {
pub struct Neighbor<N: IndexableNum> {
id: usize,
/// Squared distance
dist: N,
}

impl<N: IndexableNum> Eq for NeighborNode<N> {}
impl<N: IndexableNum> Neighbor<N> {
pub fn insertion_index(&self) -> u32 {
(self.id >> 1).try_into().unwrap()
}
}

impl<N: IndexableNum> Eq for Neighbor<N> {}

impl<N: IndexableNum> Ord for NeighborNode<N> {
impl<N: IndexableNum> Ord for Neighbor<N> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// We don't allow NaN. This should only panic on NaN
self.dist.partial_cmp(&other.dist).unwrap()
}
}

impl<N: IndexableNum> PartialOrd for NeighborNode<N> {
impl<N: IndexableNum> PartialOrd for Neighbor<N> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
Expand Down Expand Up @@ -293,7 +301,6 @@ impl<N: IndexableNum> RTreeIndex<N> for RTreeRef<'_, N> {
}

/// 1D distance from a value to a range.
#[allow(dead_code)]
#[inline]
fn axis_dist<N: IndexableNum>(k: N, min: N, max: N) -> N {
if k < min {
Expand All @@ -305,6 +312,52 @@ fn axis_dist<N: IndexableNum>(k: N, min: N, max: N) -> N {
}
}

struct Neighbors<'a, N: IndexableNum> {
boxes: &'a [N],
indices: Indices<'a>,
outer_node_index: Option<usize>,
queue: BinaryHeap<Reverse<Neighbor<N>>>,
}

impl<'a, N: IndexableNum> Neighbors<'a, N> {
fn new(boxes: &'a [N], indices: Indices<'a>) -> Self {
let outer_node_index = Some(boxes.len() - 4);
let queue = BinaryHeap::new();
Self {
boxes,
indices,
outer_node_index,
queue,
}
}
}

impl<N: IndexableNum> Iterator for Neighbors<'_, N> {
type Item = Neighbor<N>;

fn next(&mut self) -> Option<Self::Item> {
// The queue is not empty and the next item in the queue is a leaf node
if !self.queue.is_empty() && self.queue.peek().is_some_and(|val| (val.0.id & 1) != 0) {
return Some(self.queue.pop().unwrap().0);
};

if let Some(item) = self.queue.pop() {
self.outer_node_index = Some(item.0.id >> 1);
} else {
return None;
}

// Next: check if outer_node_index is not None
// Then: Add child nodes to the queue

if let Some(node_index) = self.outer_node_index {
} else {
None
}
todo!()
}
}

#[cfg(test)]
mod test {
// Replication of tests from flatbush js
Expand Down
Loading