From 1e31f023bdc3c5e712a0c32e51a1bcf1c4e188ca Mon Sep 17 00:00:00 2001 From: Kelvin Ly Date: Fri, 12 May 2023 17:05:00 -0400 Subject: [PATCH] Add some more utility functions; basic testing shows it works about the same as the original Javascript, after fixing that off by one bug --- wordle_opt/src/lib.rs | 435 +++++++++++++++++++++++++----------------- wordle_shim.js | 44 ++++- 2 files changed, 301 insertions(+), 178 deletions(-) diff --git a/wordle_opt/src/lib.rs b/wordle_opt/src/lib.rs index b13ba4b..09d37f0 100644 --- a/wordle_opt/src/lib.rs +++ b/wordle_opt/src/lib.rs @@ -1,5 +1,3 @@ -use std::alloc::{alloc, dealloc, Layout}; - extern crate wee_alloc; extern { @@ -11,185 +9,13 @@ extern { #[global_allocator] static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT; -const LIST_END: u16 = 65535; - -fn unwrap_or_abort(v: Option) -> T { - match v { - Some(v) => v, - None => std::process::abort(), - } -} - -unsafe fn alloc_ary(sz: usize) -> &'static mut [T] { - let layout = unwrap_or_abort(Layout::from_size_align( - sz*std::mem::size_of::(), - std::mem::align_of::() - ).ok()); - let ptr = alloc(layout) as *mut T; - std::slice::from_raw_parts_mut(ptr, sz) -} - -unsafe fn dealloc_ary<'a, T>(v: &'a mut [T]) { - let layout = unwrap_or_abort(Layout::from_size_align( - v.len()*std::mem::size_of::(), - std::mem::align_of::() - ).ok()); - dealloc(v.as_ptr() as *mut u8, layout); -} - -// NOTE: you must use this like a = realloc_ary(a, 100) -// otherwise you will suffer from use after free -unsafe fn realloc_ary<'a, T>(v: &'a mut [T], sz: usize) -> &'static mut [T] { - if v.len() == sz { - unsafe { std::slice::from_raw_parts_mut(v.as_mut_ptr(), v.len()) } - } else { - if v.len() > 0 { - unsafe { dealloc_ary(v) }; - } - unsafe { alloc_ary(sz) } - } -} +const WORD_SZ: usize = 5; // webassembly version of all the wordle logic // because the javascript version is slow as hell on firefox - -// TODO wrap all this in a struct to reduce the insane amount of -// unsafe in the code -pub static mut LOOKUP: &'static mut [u8] = &mut []; -pub static mut STRINGS: &'static mut [u8] = &mut []; -pub static mut STR_IDXS: &'static mut [*const u8] = &mut []; -pub static mut VALID_STRS: &'static mut [u16] = &mut []; - -// TODO maybe fiddle with the lifetimes eventually -pub struct Solver { - lookup: &'static mut [u8], - strings: &'static mut [u8], - idxs: &'static mut [*const u8], - valid_words: &'static mut [u16], -} - -impl Solver { - fn new() -> Solver { - Solver { - strings: &mut [], - idxs: &mut [], - - valid_words: &mut [], - lookup: &mut [], - } - } - - // NOTE: calls fill_string to get the string from Javascript - fn init(&mut self, str_sz: usize) { - unsafe { - self.strings = realloc_ary(self.strings, str_sz); - fill_string(self.strings.as_mut_ptr()); - self.init_index(); - log_num_idxs(self.idxs.len()); - self.reset(); - } - } - - fn init_index(&mut self) { - let mut last_alpha = false; - let mut num_words = 0; - for v in self.strings.iter() { - let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8); - if cur_alpha && !last_alpha { - num_words += 1; - } - last_alpha = cur_alpha; - } - - unsafe { - self.idxs = realloc_ary(self.idxs, num_words); - } - - let mut idx = 0; - for (i, v) in self.strings.iter().enumerate() { - let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8); - if cur_alpha && !last_alpha { - self.idxs[idx] = &self.strings[i] as *const u8; - idx += 1; - } - last_alpha = cur_alpha; - } - } - - fn reset(&mut self) { - unsafe { - self.valid_words = realloc_ary(self.valid_words, self.idxs.len()); - } - for (i, v) in self.valid_words.iter_mut().enumerate() { - *v = i as u16; - } - } - - fn count_valid_words(&self) -> usize { - self.valid_words.iter().take_while(|&v| *v != LIST_END).count() - } - - fn eliminate_words(&mut self, guess_idx: usize, guess_result: u8) { - let num_valid_words = self.count_valid_words(); - let mut idx = 0; - for i in 0..num_valid_words { - if self.lookup[guess_idx * self.idxs.len() + i] == guess_result { - self.valid_words[idx] = self.valid_words[i]; - } - } - if idx != self.valid_words.len() { - self.valid_words[idx] = LIST_END; - } - } -} - #[no_mangle] -pub fn idxs_offset(s: &Solver) -> *const *const u8 { - s.idxs.as_ptr() -} - -#[no_mangle] -pub extern fn init(str_sz: usize) -> *mut Solver { - let s = Box::new(Solver::new()); - s.init(str_sz); - let ret = unsafe { s.as_ptr_mut() }; - std::mem::forget(s); - ret -} - -#[no_mangle] -fn precalc(num_words: usize) { - let mut i = unsafe { IDX }; - let idxs = unsafe { &*STR_IDXS }; - let lookup = unsafe { &mut *LOOKUP }; - let l = idxs.len(); - - for _ in 0..num_words { - if i >= idxs.len() { - break; - } - let guess = idxs[i]; - for (j, reference) in idxs.iter().enumerate() { - lookup[i*l + j] = calc_match(guess, *reference); - } - i += 1; - } - - unsafe { - IDX = i; - } -} - -#[no_mangle] -fn precalc_done() -> bool { - unsafe { - IDX >= STR_IDXS.len() - } -} - -#[no_mangle] -fn calc_match(guess: *const u8, reference: *const u8) -> u8 { - let mut unmatched = vec![0u8; 26]; +pub extern fn calc_match(guess: *const u8, reference: *const u8, unmatched: &mut [u8; 26]) -> u8 { + unmatched.fill(0); let mut ret = 0; let mut matched = 0; @@ -221,6 +47,261 @@ fn calc_match(guess: *const u8, reference: *const u8) -> u8 { ret } + +pub struct Solver { + strings: Vec, + idxs: Vec<*const u8>, + valid_words: Vec, + lookup: Vec, + + precalc_idx: usize, + entropy_scratch: [u16; 3*3*3*3*3], +} + +impl Solver { + fn new() -> Solver { + Solver { + strings: Vec::new(), + idxs: Vec::new(), + valid_words: Vec::new(), + + lookup: Vec::new(), + precalc_idx: 0, + entropy_scratch: [0u16; 3*3*3*3*3], + } + } + + // NOTE: calls fill_string to get the string from Javascript + fn init(&mut self, str_sz: usize) { + self.strings.resize(str_sz, 0u8); + unsafe { fill_string(self.strings.as_mut_ptr()); } + self.init_index(); + let num_words = self.idxs.len(); + unsafe { log_num_idxs(num_words); } + self.reset(); + self.lookup.resize(num_words*num_words, 0); + } + + fn init_index(&mut self) { + let mut last_alpha = false; + let mut num_words = 0; + for v in self.strings.iter() { + let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8); + if cur_alpha && !last_alpha { + num_words += 1; + } + last_alpha = cur_alpha; + } + + self.idxs.resize(num_words, 0 as *const u8); + + let mut idx = 0; + for (i, v) in self.strings.iter().enumerate() { + let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8); + if cur_alpha && !last_alpha { + self.idxs[idx] = &self.strings[i] as *const u8; + idx += 1; + } + last_alpha = cur_alpha; + } + } + + fn reset(&mut self) { + self.valid_words.resize(self.idxs.len(), 65535); + for (i, v) in self.valid_words.iter_mut().enumerate() { + *v = i as u16; + } + } + + fn eliminate_words(&mut self, guess_idx: usize, guess_result: u8) { + let num_valid_words = self.valid_words.len(); + let mut idx = 0; + for i in 0..num_valid_words { + if self.lookup[guess_idx * self.idxs.len() + self.valid_words[i] as usize] == guess_result { + self.valid_words[idx] = self.valid_words[i]; + idx += 1; + } + } + self.valid_words.resize(idx, 65535); + } + + fn precalc(&mut self, num_words: usize) -> bool { + let mut tmp = [0u8; 26]; + let len = self.idxs.len(); + for _ in 0..num_words { + if self.precalc_idx >= len { + break; + } + + let i = self.precalc_idx; + for (j, v) in self.idxs.iter().enumerate() { + self.lookup[i * len + j] = calc_match(self.idxs[i], *v, &mut tmp); + } + self.precalc_idx += 1; + } + self.precalc_idx >= len + } + + fn calc_entropy(&mut self, guess_idx: usize) -> f32 { + self.entropy_scratch.fill(0); + let l = self.idxs.len(); + for idx in &self.valid_words { + self.entropy_scratch[self.lookup[guess_idx*l + (*idx as usize)] as usize] += 1; + } + let mut entropy = 0f32; + let recip = 1.0f32/(self.valid_words.len() as f32); + for v in &self.entropy_scratch { + if *v != 0 { + let prob = (*v as f32)*recip; + entropy -= prob.log2() * prob; + } + } + entropy + } + + fn best_word(&mut self) -> (isize, f32) { + if self.valid_words.len() == 1 { + return (self.valid_words[0] as isize, 0f32); + } + let mut cur_best = -1; + let mut cur_entropy = 0f32; + for i in 0..self.idxs.len() { + let entropy = self.calc_entropy(i); + if entropy > cur_entropy { + cur_entropy = entropy; + cur_best = i as isize; + } + } + (cur_best, cur_entropy) + } +} + +/* +#[no_mangle] +pub fn idxs_offset(s: &Solver) -> *const *const u8 { + (&s.idxs[0]).as_ptr() +} +*/ + +#[no_mangle] +pub extern fn init(str_sz: usize) -> *mut Solver { + let mut s = Box::new(Solver::new()); + s.init(str_sz); + Box::into_raw(s) +} + +#[no_mangle] +pub extern fn reset(s: *mut Solver) { + let solver: &mut Solver = unsafe { &mut *s }; + solver.reset(); +} + +#[no_mangle] +pub extern fn eliminate_words(s: *mut Solver, guess_idx: usize, guess_result: u8) { + let solver: &mut Solver = unsafe { &mut *s }; + solver.eliminate_words(guess_idx, guess_result); +} + +#[no_mangle] +pub extern fn words_left(s: *const Solver) -> usize { + let solver: &Solver = unsafe { &*s }; + solver.valid_words.len() +} + +#[no_mangle] +pub extern fn calc_entropy(s: *mut Solver, idx: usize) -> f32{ + let solver: &mut Solver = unsafe { &mut *s }; + solver.calc_entropy(idx) +} + + +#[no_mangle] +pub extern fn precalc(s: *mut Solver, num_words: usize) -> bool { + let solver: &mut Solver = unsafe { &mut *s }; + solver.precalc(num_words) +} + +#[no_mangle] +pub extern fn get_precalc(s: *const Solver) -> usize { + let solver: &Solver = unsafe { &*s }; + solver.precalc_idx +} + +#[no_mangle] +pub extern fn precalc_done(s: *const Solver) -> bool { + let solver: &Solver = unsafe { &*s }; + solver.precalc_idx >= solver.idxs.len() +} + +#[no_mangle] +pub extern fn lookup_word(s: *const Solver, idx: usize) -> *const u8 { + let solver: &Solver = unsafe { &*s }; + if idx < solver.idxs.len() { + solver.idxs[idx] + } else { + 0 as *const u8 + } +} + +static mut WORD_BUF: [u8; 5] = [0u8; 5]; +#[no_mangle] +pub extern fn find_word_load(_s: *const Solver) -> *mut u8 { + unsafe {&mut *(&mut WORD_BUF[0])} +} + +#[no_mangle] +pub extern fn find_word(s: *const Solver) -> isize { + let solver: &Solver = unsafe { &*s }; + let mut idx = -1; + for (i, ptr) in solver.idxs.iter().enumerate() { + let word: &[u8] = unsafe { std::slice::from_raw_parts(*ptr, 5) }; + unsafe { + if word == WORD_BUF { + idx = i as isize; + break; + } + } + } + idx +} + +#[no_mangle] +pub extern fn lookup_match(s: *const Solver, guess_idx: usize, ref_idx: usize) -> u8 { + let solver: &Solver = unsafe { &*s }; + let idx = guess_idx * solver.idxs.len() + ref_idx; + if idx < solver.lookup.len() { + solver.lookup[idx] + } else { + 0xff + } +} + + +#[no_mangle] +pub extern fn get_valid_word(s: *const Solver, idx: usize) -> isize { + let solver: &Solver = unsafe { &*s }; + if idx < solver.valid_words.len() { + solver.valid_words[idx] as isize + } else { + -1 + } +} + + +static mut LAST_ENTROPY: f32 = 0f32; +#[no_mangle] +pub extern fn best_word(s: *mut Solver) -> isize { + let solver: &mut Solver = unsafe { &mut *s }; + let (idx, entropy) = solver.best_word(); + unsafe { LAST_ENTROPY = entropy }; + idx +} + +#[no_mangle] +pub extern fn best_word_entropy(_s: *const Solver) -> f32 { + unsafe { LAST_ENTROPY } +} + #[cfg(test)] mod tests { use super::*; diff --git a/wordle_shim.js b/wordle_shim.js index d21bbcf..185837e 100644 --- a/wordle_shim.js +++ b/wordle_shim.js @@ -1,4 +1,5 @@ var wasm_solver = null +var solver = null function init_wasm_solver(words_str) { const fill_string = (offset) => { @@ -17,6 +18,47 @@ function init_wasm_solver(words_str) { } }).then(wm => { wasm_solver = wm.instance - wasm_solver.exports.init(words_str.length) + const exports = wasm_solver.exports + const solver_ptr = wasm_solver.exports.init(words_str.length) + solver = { + ptr: solver_ptr, + precalc: (nw) => exports.precalc(solver_ptr, nw), + precalc_done: () => exports.precalc_done(solver_ptr), + + reset: () => exports.reset(solver_ptr), + eliminate_words: (guess_idx, guess_result) => + exports.eliminate_words(solver_ptr, guess_idx, guess_result), + words_left: () => exports.words_left(solver_ptr), + calc_entropy: (idx) => exports.calc_entropy(solver_ptr, idx), + find_word: (w) => { + const ary = new TextEncoder().encode(w, "utf8") + const offset = exports.find_word_load(solver_ptr) + const dst = new Uint8Array(exports.memory.buffer, offset, 5) + dst.set(ary) + return exports.find_word(solver_ptr) + }, + best_word: () => { + const idx = exports.best_word(solver_ptr) + const entropy = exports.best_word_entropy(solver_ptr) + var word = null + if (idx != -1) { + const offset = exports.lookup_word(solver_ptr, idx) + word_ary = new Uint8Array(exports.memory.buffer, offset, 5) + word = new TextDecoder().decode(word_ary) + } + return [idx, entropy, word] + }, + + lookup_valid_word: (i) => { + const idx = exports.get_valid_word(solver_ptr, i) + if (idx != -1) { + const offset = exports.lookup_word(solver_ptr, idx) + word_ary = new Uint8Array(exports.memory.buffer, offset, 5) + return new TextDecoder().decode(word_ary) + } else { + return null + } + } + } }) }