Add some more utility functions; basic testing shows it works about the same as the original Javascript, after fixing that off by one bug

This commit is contained in:
Kelvin Ly 2023-05-12 17:05:00 -04:00
parent 93c5821248
commit 1e31f023bd
2 changed files with 301 additions and 178 deletions

View File

@ -1,5 +1,3 @@
use std::alloc::{alloc, dealloc, Layout};
extern crate wee_alloc;
extern {
@ -11,185 +9,13 @@ extern {
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
const LIST_END: u16 = 65535;
fn unwrap_or_abort<T>(v: Option<T>) -> T {
match v {
Some(v) => v,
None => std::process::abort(),
}
}
unsafe fn alloc_ary<T>(sz: usize) -> &'static mut [T] {
let layout = unwrap_or_abort(Layout::from_size_align(
sz*std::mem::size_of::<T>(),
std::mem::align_of::<T>()
).ok());
let ptr = alloc(layout) as *mut T;
std::slice::from_raw_parts_mut(ptr, sz)
}
unsafe fn dealloc_ary<'a, T>(v: &'a mut [T]) {
let layout = unwrap_or_abort(Layout::from_size_align(
v.len()*std::mem::size_of::<T>(),
std::mem::align_of::<T>()
).ok());
dealloc(v.as_ptr() as *mut u8, layout);
}
// NOTE: you must use this like a = realloc_ary(a, 100)
// otherwise you will suffer from use after free
unsafe fn realloc_ary<'a, T>(v: &'a mut [T], sz: usize) -> &'static mut [T] {
if v.len() == sz {
unsafe { std::slice::from_raw_parts_mut(v.as_mut_ptr(), v.len()) }
} else {
if v.len() > 0 {
unsafe { dealloc_ary(v) };
}
unsafe { alloc_ary(sz) }
}
}
const WORD_SZ: usize = 5;
// webassembly version of all the wordle logic
// because the javascript version is slow as hell on firefox
// TODO wrap all this in a struct to reduce the insane amount of
// unsafe in the code
pub static mut LOOKUP: &'static mut [u8] = &mut [];
pub static mut STRINGS: &'static mut [u8] = &mut [];
pub static mut STR_IDXS: &'static mut [*const u8] = &mut [];
pub static mut VALID_STRS: &'static mut [u16] = &mut [];
// TODO maybe fiddle with the lifetimes eventually
pub struct Solver {
lookup: &'static mut [u8],
strings: &'static mut [u8],
idxs: &'static mut [*const u8],
valid_words: &'static mut [u16],
}
impl Solver {
fn new() -> Solver {
Solver {
strings: &mut [],
idxs: &mut [],
valid_words: &mut [],
lookup: &mut [],
}
}
// NOTE: calls fill_string to get the string from Javascript
fn init(&mut self, str_sz: usize) {
unsafe {
self.strings = realloc_ary(self.strings, str_sz);
fill_string(self.strings.as_mut_ptr());
self.init_index();
log_num_idxs(self.idxs.len());
self.reset();
}
}
fn init_index(&mut self) {
let mut last_alpha = false;
let mut num_words = 0;
for v in self.strings.iter() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
num_words += 1;
}
last_alpha = cur_alpha;
}
unsafe {
self.idxs = realloc_ary(self.idxs, num_words);
}
let mut idx = 0;
for (i, v) in self.strings.iter().enumerate() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
self.idxs[idx] = &self.strings[i] as *const u8;
idx += 1;
}
last_alpha = cur_alpha;
}
}
fn reset(&mut self) {
unsafe {
self.valid_words = realloc_ary(self.valid_words, self.idxs.len());
}
for (i, v) in self.valid_words.iter_mut().enumerate() {
*v = i as u16;
}
}
fn count_valid_words(&self) -> usize {
self.valid_words.iter().take_while(|&v| *v != LIST_END).count()
}
fn eliminate_words(&mut self, guess_idx: usize, guess_result: u8) {
let num_valid_words = self.count_valid_words();
let mut idx = 0;
for i in 0..num_valid_words {
if self.lookup[guess_idx * self.idxs.len() + i] == guess_result {
self.valid_words[idx] = self.valid_words[i];
}
}
if idx != self.valid_words.len() {
self.valid_words[idx] = LIST_END;
}
}
}
#[no_mangle]
pub fn idxs_offset(s: &Solver) -> *const *const u8 {
s.idxs.as_ptr()
}
#[no_mangle]
pub extern fn init(str_sz: usize) -> *mut Solver {
let s = Box::new(Solver::new());
s.init(str_sz);
let ret = unsafe { s.as_ptr_mut() };
std::mem::forget(s);
ret
}
#[no_mangle]
fn precalc(num_words: usize) {
let mut i = unsafe { IDX };
let idxs = unsafe { &*STR_IDXS };
let lookup = unsafe { &mut *LOOKUP };
let l = idxs.len();
for _ in 0..num_words {
if i >= idxs.len() {
break;
}
let guess = idxs[i];
for (j, reference) in idxs.iter().enumerate() {
lookup[i*l + j] = calc_match(guess, *reference);
}
i += 1;
}
unsafe {
IDX = i;
}
}
#[no_mangle]
fn precalc_done() -> bool {
unsafe {
IDX >= STR_IDXS.len()
}
}
#[no_mangle]
fn calc_match(guess: *const u8, reference: *const u8) -> u8 {
let mut unmatched = vec![0u8; 26];
pub extern fn calc_match(guess: *const u8, reference: *const u8, unmatched: &mut [u8; 26]) -> u8 {
unmatched.fill(0);
let mut ret = 0;
let mut matched = 0;
@ -221,6 +47,261 @@ fn calc_match(guess: *const u8, reference: *const u8) -> u8 {
ret
}
pub struct Solver {
strings: Vec<u8>,
idxs: Vec<*const u8>,
valid_words: Vec<u16>,
lookup: Vec<u8>,
precalc_idx: usize,
entropy_scratch: [u16; 3*3*3*3*3],
}
impl Solver {
fn new() -> Solver {
Solver {
strings: Vec::new(),
idxs: Vec::new(),
valid_words: Vec::new(),
lookup: Vec::new(),
precalc_idx: 0,
entropy_scratch: [0u16; 3*3*3*3*3],
}
}
// NOTE: calls fill_string to get the string from Javascript
fn init(&mut self, str_sz: usize) {
self.strings.resize(str_sz, 0u8);
unsafe { fill_string(self.strings.as_mut_ptr()); }
self.init_index();
let num_words = self.idxs.len();
unsafe { log_num_idxs(num_words); }
self.reset();
self.lookup.resize(num_words*num_words, 0);
}
fn init_index(&mut self) {
let mut last_alpha = false;
let mut num_words = 0;
for v in self.strings.iter() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
num_words += 1;
}
last_alpha = cur_alpha;
}
self.idxs.resize(num_words, 0 as *const u8);
let mut idx = 0;
for (i, v) in self.strings.iter().enumerate() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
self.idxs[idx] = &self.strings[i] as *const u8;
idx += 1;
}
last_alpha = cur_alpha;
}
}
fn reset(&mut self) {
self.valid_words.resize(self.idxs.len(), 65535);
for (i, v) in self.valid_words.iter_mut().enumerate() {
*v = i as u16;
}
}
fn eliminate_words(&mut self, guess_idx: usize, guess_result: u8) {
let num_valid_words = self.valid_words.len();
let mut idx = 0;
for i in 0..num_valid_words {
if self.lookup[guess_idx * self.idxs.len() + self.valid_words[i] as usize] == guess_result {
self.valid_words[idx] = self.valid_words[i];
idx += 1;
}
}
self.valid_words.resize(idx, 65535);
}
fn precalc(&mut self, num_words: usize) -> bool {
let mut tmp = [0u8; 26];
let len = self.idxs.len();
for _ in 0..num_words {
if self.precalc_idx >= len {
break;
}
let i = self.precalc_idx;
for (j, v) in self.idxs.iter().enumerate() {
self.lookup[i * len + j] = calc_match(self.idxs[i], *v, &mut tmp);
}
self.precalc_idx += 1;
}
self.precalc_idx >= len
}
fn calc_entropy(&mut self, guess_idx: usize) -> f32 {
self.entropy_scratch.fill(0);
let l = self.idxs.len();
for idx in &self.valid_words {
self.entropy_scratch[self.lookup[guess_idx*l + (*idx as usize)] as usize] += 1;
}
let mut entropy = 0f32;
let recip = 1.0f32/(self.valid_words.len() as f32);
for v in &self.entropy_scratch {
if *v != 0 {
let prob = (*v as f32)*recip;
entropy -= prob.log2() * prob;
}
}
entropy
}
fn best_word(&mut self) -> (isize, f32) {
if self.valid_words.len() == 1 {
return (self.valid_words[0] as isize, 0f32);
}
let mut cur_best = -1;
let mut cur_entropy = 0f32;
for i in 0..self.idxs.len() {
let entropy = self.calc_entropy(i);
if entropy > cur_entropy {
cur_entropy = entropy;
cur_best = i as isize;
}
}
(cur_best, cur_entropy)
}
}
/*
#[no_mangle]
pub fn idxs_offset(s: &Solver) -> *const *const u8 {
(&s.idxs[0]).as_ptr()
}
*/
#[no_mangle]
pub extern fn init(str_sz: usize) -> *mut Solver {
let mut s = Box::new(Solver::new());
s.init(str_sz);
Box::into_raw(s)
}
#[no_mangle]
pub extern fn reset(s: *mut Solver) {
let solver: &mut Solver = unsafe { &mut *s };
solver.reset();
}
#[no_mangle]
pub extern fn eliminate_words(s: *mut Solver, guess_idx: usize, guess_result: u8) {
let solver: &mut Solver = unsafe { &mut *s };
solver.eliminate_words(guess_idx, guess_result);
}
#[no_mangle]
pub extern fn words_left(s: *const Solver) -> usize {
let solver: &Solver = unsafe { &*s };
solver.valid_words.len()
}
#[no_mangle]
pub extern fn calc_entropy(s: *mut Solver, idx: usize) -> f32{
let solver: &mut Solver = unsafe { &mut *s };
solver.calc_entropy(idx)
}
#[no_mangle]
pub extern fn precalc(s: *mut Solver, num_words: usize) -> bool {
let solver: &mut Solver = unsafe { &mut *s };
solver.precalc(num_words)
}
#[no_mangle]
pub extern fn get_precalc(s: *const Solver) -> usize {
let solver: &Solver = unsafe { &*s };
solver.precalc_idx
}
#[no_mangle]
pub extern fn precalc_done(s: *const Solver) -> bool {
let solver: &Solver = unsafe { &*s };
solver.precalc_idx >= solver.idxs.len()
}
#[no_mangle]
pub extern fn lookup_word(s: *const Solver, idx: usize) -> *const u8 {
let solver: &Solver = unsafe { &*s };
if idx < solver.idxs.len() {
solver.idxs[idx]
} else {
0 as *const u8
}
}
static mut WORD_BUF: [u8; 5] = [0u8; 5];
#[no_mangle]
pub extern fn find_word_load(_s: *const Solver) -> *mut u8 {
unsafe {&mut *(&mut WORD_BUF[0])}
}
#[no_mangle]
pub extern fn find_word(s: *const Solver) -> isize {
let solver: &Solver = unsafe { &*s };
let mut idx = -1;
for (i, ptr) in solver.idxs.iter().enumerate() {
let word: &[u8] = unsafe { std::slice::from_raw_parts(*ptr, 5) };
unsafe {
if word == WORD_BUF {
idx = i as isize;
break;
}
}
}
idx
}
#[no_mangle]
pub extern fn lookup_match(s: *const Solver, guess_idx: usize, ref_idx: usize) -> u8 {
let solver: &Solver = unsafe { &*s };
let idx = guess_idx * solver.idxs.len() + ref_idx;
if idx < solver.lookup.len() {
solver.lookup[idx]
} else {
0xff
}
}
#[no_mangle]
pub extern fn get_valid_word(s: *const Solver, idx: usize) -> isize {
let solver: &Solver = unsafe { &*s };
if idx < solver.valid_words.len() {
solver.valid_words[idx] as isize
} else {
-1
}
}
static mut LAST_ENTROPY: f32 = 0f32;
#[no_mangle]
pub extern fn best_word(s: *mut Solver) -> isize {
let solver: &mut Solver = unsafe { &mut *s };
let (idx, entropy) = solver.best_word();
unsafe { LAST_ENTROPY = entropy };
idx
}
#[no_mangle]
pub extern fn best_word_entropy(_s: *const Solver) -> f32 {
unsafe { LAST_ENTROPY }
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,4 +1,5 @@
var wasm_solver = null
var solver = null
function init_wasm_solver(words_str) {
const fill_string = (offset) => {
@ -17,6 +18,47 @@ function init_wasm_solver(words_str) {
}
}).then(wm => {
wasm_solver = wm.instance
wasm_solver.exports.init(words_str.length)
const exports = wasm_solver.exports
const solver_ptr = wasm_solver.exports.init(words_str.length)
solver = {
ptr: solver_ptr,
precalc: (nw) => exports.precalc(solver_ptr, nw),
precalc_done: () => exports.precalc_done(solver_ptr),
reset: () => exports.reset(solver_ptr),
eliminate_words: (guess_idx, guess_result) =>
exports.eliminate_words(solver_ptr, guess_idx, guess_result),
words_left: () => exports.words_left(solver_ptr),
calc_entropy: (idx) => exports.calc_entropy(solver_ptr, idx),
find_word: (w) => {
const ary = new TextEncoder().encode(w, "utf8")
const offset = exports.find_word_load(solver_ptr)
const dst = new Uint8Array(exports.memory.buffer, offset, 5)
dst.set(ary)
return exports.find_word(solver_ptr)
},
best_word: () => {
const idx = exports.best_word(solver_ptr)
const entropy = exports.best_word_entropy(solver_ptr)
var word = null
if (idx != -1) {
const offset = exports.lookup_word(solver_ptr, idx)
word_ary = new Uint8Array(exports.memory.buffer, offset, 5)
word = new TextDecoder().decode(word_ary)
}
return [idx, entropy, word]
},
lookup_valid_word: (i) => {
const idx = exports.get_valid_word(solver_ptr, i)
if (idx != -1) {
const offset = exports.lookup_word(solver_ptr, idx)
word_ary = new Uint8Array(exports.memory.buffer, offset, 5)
return new TextDecoder().decode(word_ary)
} else {
return null
}
}
}
})
}