Wrap everything in a struct

This commit is contained in:
Kelvin Ly 2023-05-12 15:26:13 -04:00
parent 74812a6a01
commit 93c5821248
1 changed files with 115 additions and 70 deletions

View File

@ -2,25 +2,16 @@ use std::alloc::{alloc, dealloc, Layout};
extern crate wee_alloc;
extern {
fn fill_string(p: *mut u8);
fn log_num_idxs(i: usize);
}
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
// webassembly version of all the wordle logic
// because the javascript version is slow as hell on firefox
// TODO wrap all this in a struct to reduce the insane amount of
// unsafe in the code
pub static mut LOOKUP: &'static mut [u8] = &mut [];
pub static mut STRINGS: &'static mut [u8] = &mut [];
pub static mut STR_IDXS: &'static mut [*const u8] = &mut [];
pub static mut VALID_STRS: &'static mut [usize] = &mut [];
#[no_mangle]
pub fn idxs_offset() -> *const *const u8 {
unsafe {
STR_IDXS.as_ptr()
}
}
const LIST_END: u16 = 65535;
fn unwrap_or_abort<T>(v: Option<T>) -> T {
match v {
@ -38,7 +29,7 @@ unsafe fn alloc_ary<T>(sz: usize) -> &'static mut [T] {
std::slice::from_raw_parts_mut(ptr, sz)
}
unsafe fn dealloc_ary<T>(v: &'static mut [T]) {
unsafe fn dealloc_ary<'a, T>(v: &'a mut [T]) {
let layout = unwrap_or_abort(Layout::from_size_align(
v.len()*std::mem::size_of::<T>(),
std::mem::align_of::<T>()
@ -46,72 +37,126 @@ unsafe fn dealloc_ary<T>(v: &'static mut [T]) {
dealloc(v.as_ptr() as *mut u8, layout);
}
fn realloc_ary<T>(v: &'static mut [T], sz: usize) -> &'static mut [T] {
unsafe {
// NOTE: you must use this like a = realloc_ary(a, 100)
// otherwise you will suffer from use after free
unsafe fn realloc_ary<'a, T>(v: &'a mut [T], sz: usize) -> &'static mut [T] {
if v.len() == sz {
unsafe { std::slice::from_raw_parts_mut(v.as_mut_ptr(), v.len()) }
} else {
if v.len() > 0 {
dealloc_ary(v);
unsafe { dealloc_ary(v) };
}
alloc_ary(sz)
unsafe { alloc_ary(sz) }
}
}
// webassembly version of all the wordle logic
// because the javascript version is slow as hell on firefox
extern {
fn fill_string(p: *mut u8);
fn log_num_idxs(i: usize);
// TODO wrap all this in a struct to reduce the insane amount of
// unsafe in the code
pub static mut LOOKUP: &'static mut [u8] = &mut [];
pub static mut STRINGS: &'static mut [u8] = &mut [];
pub static mut STR_IDXS: &'static mut [*const u8] = &mut [];
pub static mut VALID_STRS: &'static mut [u16] = &mut [];
// TODO maybe fiddle with the lifetimes eventually
pub struct Solver {
lookup: &'static mut [u8],
strings: &'static mut [u8],
idxs: &'static mut [*const u8],
valid_words: &'static mut [u16],
}
impl Solver {
fn new() -> Solver {
Solver {
strings: &mut [],
idxs: &mut [],
valid_words: &mut [],
lookup: &mut [],
}
}
// NOTE: calls fill_string to get the string from Javascript
fn init(&mut self, str_sz: usize) {
unsafe {
self.strings = realloc_ary(self.strings, str_sz);
fill_string(self.strings.as_mut_ptr());
self.init_index();
log_num_idxs(self.idxs.len());
self.reset();
}
}
fn init_index(&mut self) {
let mut last_alpha = false;
let mut num_words = 0;
for v in self.strings.iter() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
num_words += 1;
}
last_alpha = cur_alpha;
}
unsafe {
self.idxs = realloc_ary(self.idxs, num_words);
}
let mut idx = 0;
for (i, v) in self.strings.iter().enumerate() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
self.idxs[idx] = &self.strings[i] as *const u8;
idx += 1;
}
last_alpha = cur_alpha;
}
}
fn reset(&mut self) {
unsafe {
self.valid_words = realloc_ary(self.valid_words, self.idxs.len());
}
for (i, v) in self.valid_words.iter_mut().enumerate() {
*v = i as u16;
}
}
fn count_valid_words(&self) -> usize {
self.valid_words.iter().take_while(|&v| *v != LIST_END).count()
}
fn eliminate_words(&mut self, guess_idx: usize, guess_result: u8) {
let num_valid_words = self.count_valid_words();
let mut idx = 0;
for i in 0..num_valid_words {
if self.lookup[guess_idx * self.idxs.len() + i] == guess_result {
self.valid_words[idx] = self.valid_words[i];
}
}
if idx != self.valid_words.len() {
self.valid_words[idx] = LIST_END;
}
}
}
#[no_mangle]
pub extern fn init(str_sz: usize) {
unsafe {
STRINGS = realloc_ary(STRINGS, str_sz);
fill_string(STRINGS.as_mut_ptr());
}
let num_strs = init_idx();
unsafe {
log_num_idxs(num_strs);
let num_strings = STR_IDXS.len();
LOOKUP = realloc_ary(LOOKUP, num_strings*num_strings);
}
pub fn idxs_offset(s: &Solver) -> *const *const u8 {
s.idxs.as_ptr()
}
fn init_idx() -> usize {
let strings = unsafe { &*STRINGS };
let mut last_alpha = false;
let mut num_words = 0;
for v in strings.iter() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
num_words += 1;
}
last_alpha = cur_alpha;
}
unsafe {
if STR_IDXS.len() > 0 {
dealloc_ary(STR_IDXS);
}
STR_IDXS = alloc_ary(num_words);
}
let idxs = unsafe { &mut *STR_IDXS };
let mut idx = 0;
for (i, v) in strings.iter().enumerate() {
let cur_alpha = *v >= ('a' as u8) && *v <= ('z' as u8);
if cur_alpha && !last_alpha {
idxs[idx] = &strings[i] as *const u8;
idx += 1;
}
last_alpha = cur_alpha;
}
num_words
#[no_mangle]
pub extern fn init(str_sz: usize) -> *mut Solver {
let s = Box::new(Solver::new());
s.init(str_sz);
let ret = unsafe { s.as_ptr_mut() };
std::mem::forget(s);
ret
}
static mut IDX: usize = 0;
#[no_mangle]
fn precalc(num_words: usize) {
let mut i = unsafe { IDX };