From 88f778fac569d2723d1f4dac305760b8c31070cc Mon Sep 17 00:00:00 2001 From: Umberto Griffo <1609440+umbertogriffo@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:56:58 +0100 Subject: [PATCH] fix: store key-value pairs in a hashmap while preserving the order of insertion --- .../rust-llm/Cargo.lock | 43 ++++-- .../rust-llm/Cargo.toml | 1 + .../rust-llm/src/bpe.rs | 145 ++++++++++-------- .../rust-llm/src/helpers.rs | 21 ++- 4 files changed, 133 insertions(+), 77 deletions(-) diff --git a/002-Rust-bindings-to-Python/rust-llm/Cargo.lock b/002-Rust-bindings-to-Python/rust-llm/Cargo.lock index 5880cfe..a1ef8b7 100644 --- a/002-Rust-bindings-to-Python/rust-llm/Cargo.lock +++ b/002-Rust-bindings-to-Python/rust-llm/Cargo.lock @@ -26,12 +26,34 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "indexmap" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "indoc" version = "2.0.5" @@ -46,9 +68,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.158" +version = "0.2.159" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" [[package]] name = "lock_api" @@ -100,9 +122,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce" [[package]] name = "proc-macro2" @@ -187,9 +209,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.3" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0" dependencies = [ "bitflags", ] @@ -199,6 +221,7 @@ name = "rust_llm" version = "0.1.0" dependencies = [ "base64", + "indexmap", "lazy_static", "pyo3", "unicode_categories", @@ -218,9 +241,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "syn" -version = "2.0.76" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", @@ -235,9 +258,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode_categories" diff --git 
a/002-Rust-bindings-to-Python/rust-llm/Cargo.toml b/002-Rust-bindings-to-Python/rust-llm/Cargo.toml index bf72302..dcb89c7 100644 --- a/002-Rust-bindings-to-Python/rust-llm/Cargo.toml +++ b/002-Rust-bindings-to-Python/rust-llm/Cargo.toml @@ -12,3 +12,4 @@ pyo3 = { version = "0.20.0", features = ["extension-module"] } lazy_static = "1.4.0" base64 = "0.22.1" unicode_categories = "0.1" +indexmap = "2.5.0" diff --git a/002-Rust-bindings-to-Python/rust-llm/src/bpe.rs b/002-Rust-bindings-to-Python/rust-llm/src/bpe.rs index e7de2d7..9479efb 100644 --- a/002-Rust-bindings-to-Python/rust-llm/src/bpe.rs +++ b/002-Rust-bindings-to-Python/rust-llm/src/bpe.rs @@ -5,6 +5,7 @@ use pyo3::prelude::*; use pyo3::prepare_freethreaded_python; +use indexmap::IndexMap; use std::collections::BTreeMap; use std::fs::File; use std::io::{BufRead, BufReader, Write}; @@ -14,9 +15,9 @@ use crate::helpers::{get_stats, merge, render_token, build_vocab, b_as_literal}; // The #[pyclass] macro from PyO3 is used to map this as a Python class #[pyclass] pub struct BPETokenizer { - merges: BTreeMap<(i32, i32), i32>, // (pair, idx) + merges: IndexMap<(i32, i32), i32>, // (pair, idx) pattern: String, // pattern to split the text into tokens - special_tokens: BTreeMap<String, i32>, // special tokens (string, idx) + special_tokens: IndexMap<String, i32>, // special tokens (string, idx) vocab: BTreeMap<i32, Vec<u8>>, // (idx, token) } @@ -27,10 +28,10 @@ impl BPETokenizer { #[new] pub fn new() -> Self { prepare_freethreaded_python(); - let merges = BTreeMap::<(i32, i32), i32>::new(); - let pattern = String::new(); - let special_tokens = BTreeMap::<String, i32>::new(); - let vocab = build_vocab(&merges, &special_tokens); + let merges: IndexMap<(i32, i32), i32> = IndexMap::<(i32, i32), i32>::new(); + let pattern: String = String::new(); + let special_tokens: IndexMap<String, i32> = IndexMap::<String, i32>::new(); + let vocab: BTreeMap<i32, Vec<u8>> = build_vocab(&merges, &special_tokens); BPETokenizer { merges, @@ -40,7 +41,7 @@ impl BPETokenizer { } - fn train(&mut self, text: &str, vocab_size: 
usize, verbose: bool) { + pub fn train(&mut self, text: &str, vocab_size: usize, verbose: bool) { assert!(vocab_size >= 256); let num_merges = vocab_size - 256; @@ -58,6 +59,9 @@ impl BPETokenizer { for i in 0..num_merges { let stats = get_stats(&ids); + if stats.is_empty() { + break; + } let pair = *stats.iter().max_by_key(|&(_, count)| count).unwrap().0; let idx = 256 + i as i32; ids = merge(&ids, pair, idx); @@ -65,22 +69,26 @@ impl BPETokenizer { self.merges.insert(pair, idx); // Separate immutable borrow before mutable borrow - let pair_1_len = self.vocab[&pair.1].len(); - if let Some(vocab_entry) = self.vocab.get_mut(&pair.0) { - let mut new_token = Vec::with_capacity(vocab_entry.len() + pair_1_len); - new_token.extend_from_slice(vocab_entry); - new_token.extend_from_slice(&self.vocab[&pair.1]); - self.vocab.insert(idx, new_token); + if let (Some(vocab_entry_p0), Some(vocab_entry_p1)) = (self.vocab.get(&pair.0), self.vocab.get(&pair.1)) { + let mut merged_vec = Vec::with_capacity(vocab_entry_p0.len() + vocab_entry_p1.len()); + merged_vec.extend_from_slice(vocab_entry_p0); + merged_vec.extend_from_slice(vocab_entry_p1); + self.vocab.insert(idx, merged_vec); + } + else { + // Handle the case where pair.0 or pair.1 does not exist in vocab + eprintln!("Warning: Missing key in vocab for pair ({}, {})", pair.0, pair.1); } if verbose { println!( - "merge {}/{}: {:?} -> {} ({:?}) had {} occurrences", + "merge {}/{}: {:?} -> {} ({:?}, binary: {:?}) had {} occurrences", i + 1, num_merges, pair, idx, self.vocab[&idx], + b_as_literal(&self.vocab[&idx]), stats[&pair] ); } @@ -89,6 +97,7 @@ impl BPETokenizer { #[cfg(debug_assertions)] { println!("[DEBUG] Merges: {:?}", self.merges); + println!("[DEBUG] Vocabulary: "); for (key, value) in &self.vocab { let formatted_value = b_as_literal(value); println!("{}: {}", key, formatted_value); @@ -96,10 +105,58 @@ impl BPETokenizer { } } + pub fn encode(&self, text: &str, verbose: bool) -> Vec<i32> { + if verbose { + println!("[DEBUG] Encoding text: 
{}", text); + } + let mut ids: Vec<i32> = text.bytes().map(|b| b as i32).collect(); + let merges = &self.merges; + if verbose { + println!("[DEBUG] Initial IDs: {:?}", ids); + println!("[DEBUG] Merges loaded: {:?}", merges); + } + + while ids.len() >= 2 { + let stats = get_stats(&ids); + if verbose { + println!("[DEBUG] Current IDs: {:?}", ids); + println!("[DEBUG] Stats: {:?}", stats); + } + let pair = stats + .keys() + .min_by_key(|&&p| merges.get(&p).unwrap_or(&i32::MAX)) + .copied() + .unwrap(); + + if !merges.contains_key(&pair) { + break; + } + + let new_id = merges[&pair]; + ids = merge(&ids, pair, new_id); + if verbose { + println!("[DEBUG] Pair: {:?}", pair); + println!("[DEBUG] New ID for merge: {}", new_id); + } + } + + if verbose { + println!("[DEBUG] Encoding complete. Final IDs: {:?}", ids); + } + ids + } + pub fn decode(&self, ids: Vec<i32>) -> String { let mut text_bytes = Vec::with_capacity(ids.len() * 2); // Pre-allocate assuming 2 bytes per id on average + for &id in &ids { - text_bytes.extend_from_slice(&self.vocab[&id]); + if let Some(vocab_entry) = self.vocab.get(&id) { + text_bytes.extend_from_slice(vocab_entry); + } + else { + // Handle the case where the id does not exist in vocab + eprintln!("Warning: Missing key in vocab for id {}", id); + } } String::from_utf8_lossy(&text_bytes).into_owned() } @@ -121,7 +178,7 @@ impl BPETokenizer { let vocab_file = format!("{}.vocab", file_prefix); let mut vocab_file = File::create(&vocab_file)?; - let inverted_merges: BTreeMap<i32, (i32, i32)> = self.merges + let inverted_merges: IndexMap<i32, (i32, i32)> = self.merges .iter() .map(|(&(p0, p1), &idx)| (idx, (p0, p1))) .collect(); @@ -131,9 +188,9 @@ impl BPETokenizer { if let Some(&(idx0, idx1)) = inverted_merges.get(&idx) { let s0 = render_token(&self.vocab[&idx0]); let s1 = render_token(&self.vocab[&idx1]); - writeln!(vocab_file, "[{}][{}] -> [{}] {} \n", s0, s1, s, idx)?; + writeln!(vocab_file, "[{}][{}] -> [{}] {}", s0, s1, s, idx)?; } else { - writeln!(vocab_file, "[{}] {} \n", s, idx)?; 
+ writeln!(vocab_file, "[{}] {}", s, idx)?; } } Ok(()) } @@ -154,7 +211,7 @@ impl BPETokenizer { self.pattern = lines.next().unwrap().unwrap().trim().to_string(); let num_special_tokens = lines.next().unwrap().unwrap().parse::<usize>().unwrap(); - let mut special_tokens = BTreeMap::<String, i32>::new(); + let mut special_tokens = IndexMap::<String, i32>::new(); for _ in 0..num_special_tokens { let line = lines.next().unwrap().unwrap(); @@ -162,7 +219,7 @@ impl BPETokenizer { special_tokens.insert(parts[0].to_string(), parts[1].parse::<i32>().unwrap()); } - let mut merges = BTreeMap::<(i32, i32), i32>::new(); + let mut merges = IndexMap::<(i32, i32), i32>::new(); let mut idx = 256; for line in lines { let line = line.unwrap(); @@ -173,50 +230,18 @@ impl BPETokenizer { idx += 1; } + if verbose { + for ((p0, p1), idx) in &merges { + println!("({},{}): {}", p0, p1, idx); + } + } + self.merges = merges; + self.special_tokens = special_tokens; self.vocab = build_vocab(&self.merges, &self.special_tokens); Ok(()) } - pub fn encode(&self, text: &str, verbose: bool) -> Vec<i32> { - if verbose { - println!("[DEBUG] Encoding text: {}", text); - } - let mut ids: Vec<i32> = text.bytes().map(|b| b as i32).collect(); - let merges = &self.merges; - if verbose { - println!("[DEBUG] Initial IDs: {:?}", ids); - println!("[DEBUG] Merges loaded: {:?}", merges); - } - - while ids.len() >= 2 { - let stats = get_stats(&ids); - if verbose { - println!("[DEBUG] Current IDs: {:?}", ids); - println!("[DEBUG] Stats: {:?}", stats); - } - let pair = stats - .keys() - .min_by_key(|&&p| merges.get(&p).unwrap_or(&i32::MAX)) - .copied() - .unwrap(); - - if !merges.contains_key(&pair) { - break; - } - - let new_id = merges[&pair]; - ids = merge(&ids, pair, new_id); - if verbose { - println!("[DEBUG] Pair: {:?}", pair); - println!("[DEBUG] New ID for merge: {}", new_id); - } - } - if verbose { - println!("[DEBUG] Encoding complete. 
Final IDs: {:?}", ids); - } - ids - } -} \ No newline at end of file +} diff --git a/002-Rust-bindings-to-Python/rust-llm/src/helpers.rs b/002-Rust-bindings-to-Python/rust-llm/src/helpers.rs index 58a2282..d3cfcd6 100644 --- a/002-Rust-bindings-to-Python/rust-llm/src/helpers.rs +++ b/002-Rust-bindings-to-Python/rust-llm/src/helpers.rs @@ -1,12 +1,14 @@ +use std::collections::HashMap; use std::collections::BTreeMap; +use indexmap::IndexMap; use std::str; // Helpers // Given a list of integers, return a HashMap of counts of consecutive pairs // Example: vec[1,2,3,1,2] -> HMap {(1,2): 2, (2,3): 1, (3,1): 1} -pub fn get_stats(ids: &[i32]) -> BTreeMap<(i32, i32), usize> { - let mut counts = BTreeMap::new(); +pub fn get_stats(ids: &[i32]) -> HashMap<(i32, i32), usize> { + let mut counts = HashMap::new(); for pair in ids.windows(2) { let pair = (pair[0], pair[1]); *counts.entry(pair).or_insert(0) += 1; @@ -29,15 +31,20 @@ pub fn merge(ids: &[i32], pair: (i32, i32), idx: i32) -> Vec<i32> { newids } -pub fn build_vocab(merges: &BTreeMap<(i32, i32), i32>, special_tokens: &BTreeMap<String, i32>) -> BTreeMap<i32, Vec<u8>> { +pub fn build_vocab(merges: &IndexMap<(i32, i32), i32>, special_tokens: &IndexMap<String, i32>) -> BTreeMap<i32, Vec<u8>> { + // the base vocabulary contains all 256 single-byte values let mut vocab: BTreeMap<i32, Vec<u8>> = (0..256).map(|idx| (idx as i32, vec![idx as u8])).collect(); for (&(p0, p1), &idx) in merges { - if let Some(vocab_entry) = vocab.get(&p0) { - let mut merged_vec = Vec::with_capacity(vocab_entry.len() + vocab[&p1].len()); - merged_vec.extend_from_slice(vocab_entry); - merged_vec.extend_from_slice(&vocab[&p1]); + if let (Some(vocab_entry_p0), Some(vocab_entry_p1)) = (vocab.get(&p0), vocab.get(&p1)) { + let mut merged_vec = Vec::with_capacity(vocab_entry_p0.len() + vocab_entry_p1.len()); + merged_vec.extend_from_slice(vocab_entry_p0); + merged_vec.extend_from_slice(vocab_entry_p1); + vocab.insert(idx, merged_vec); + } + else { + // Handle the case where p0 or p1 does not exist in vocab + eprintln!("Warning: 
Missing key in vocab for pair ({}, {})", p0, p1); + } } for (special, &idx) in special_tokens {