From 92e4cfc123a1d26265128634850a2b73bac761c2 Mon Sep 17 00:00:00 2001 From: wires Date: Fri, 24 Oct 2025 11:29:46 -0400 Subject: first draft of sqlite wrapper --- wyrd_sqlite/src/lib.rs | 436 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 436 insertions(+) create mode 100644 wyrd_sqlite/src/lib.rs (limited to 'wyrd_sqlite/src/lib.rs') diff --git a/wyrd_sqlite/src/lib.rs b/wyrd_sqlite/src/lib.rs new file mode 100644 index 0000000..2ef498a --- /dev/null +++ b/wyrd_sqlite/src/lib.rs @@ -0,0 +1,436 @@ +//! thin wrapper around the bits of sqlite we need. why not just use +//! [rusqlite](https://docs.rs/rusqlite) or similar? i ran into a minor annoyance with their API +//! and decided it would be fun to reinvent the wheel a bit and see how i could do. that said, this +//! implementation owes a great deal to theirs. +#![deny(clippy::undocumented_unsafe_blocks)] +#![deny(clippy::missing_safety_doc)] +#![deny(unsafe_op_in_unsafe_fn)] + +use std::{ + cell::RefCell, + ffi::{CString, c_char, c_int, c_uint}, + marker::PhantomData, + mem::ManuallyDrop, + ops::{Deref, DerefMut}, + ptr, slice, + str::FromStr, +}; + +use libsqlite3_sys as ffi; + +mod error; +mod from_sql; +mod params; + +pub use error::{Error, GetError, GetResult, Result}; +pub use from_sql::{FromSql, Value}; +pub use params::{Param, Params}; + +fn version_number() -> i32 { + // SAFETY: trivial wrapper + unsafe { ffi::sqlite3_libversion_number() } +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug)] + #[repr(C)] + pub struct OpenFlags: c_int { + const READONLY = ffi::SQLITE_OPEN_READONLY; + const READWRITE = ffi::SQLITE_OPEN_READWRITE; + const CREATE = ffi::SQLITE_OPEN_CREATE; + const URI = ffi::SQLITE_OPEN_URI; + const MEMORY = ffi::SQLITE_OPEN_MEMORY; + const NOMUTEX = ffi::SQLITE_OPEN_NOMUTEX; + const NOFOLLOW = ffi::SQLITE_OPEN_NOFOLLOW; + const EXRESCODE = ffi::SQLITE_OPEN_EXRESCODE; + } +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, Default)] + #[repr(C)] + pub struct PrepFlags: c_uint { + const PERSISTENT = ffi::SQLITE_PREPARE_PERSISTENT; + const NO_VTAB = ffi::SQLITE_PREPARE_NO_VTAB; + const DONT_LOG = ffi::SQLITE_PREPARE_DONT_LOG; + } +} + +impl Default for OpenFlags { + fn default() -> Self { + if version_number() > 3_037_000 { + Self::READWRITE | Self::CREATE | Self::NOMUTEX | Self::URI | Self::EXRESCODE + } else { + Self::READWRITE | Self::CREATE | Self::NOMUTEX | Self::URI + } + } +} + +pub struct Connection(RefCell); + +impl Connection { + pub fn open(filename: &str) -> Result { + Self::open_with_flags(filename, OpenFlags::default()) + } + + pub fn open_with_flags(filename: &str, flags: OpenFlags) -> Result { + ConnectionInner::open(filename, flags) + .map(RefCell::new) + .map(Self) + } + + pub fn prepare(&self, src: &str) -> Result<(Statement<'_>, usize)> { + self.prepare_with_flags(src, PrepFlags::default()) + } + + pub fn prepare_with_flags( + &self, + src: &str, + flags: PrepFlags, + ) -> Result<(Statement<'_>, usize)> { + self.prepare_raw(src, flags) + .map(|(raw, read)| (Statement { raw, conn: self }, read)) + } + + pub fn prepare_raw(&self, src: &str, flags: PrepFlags) -> Result<(RawStatement, usize)> { + self.0.borrow_mut().prepare_raw(src, flags) + } + + fn get_error(&self) -> Error { + self.0.borrow_mut().get_error() + } + + fn decode_response(&self, code: c_int) -> Result<()> { + self.0.borrow_mut().decode_response(code) + } + + fn changes(&self) -> usize { + self.0.borrow_mut().changes() + } + + pub fn execute(&self, src: &str, params: P) -> Result { + let (mut stmt, read) = self.prepare(src)?; + if read != src.len() { + Err(Error::MultipleStatements) + } else { + stmt.execute(params) + } + } +} + +struct ConnectionInner { + db: *mut ffi::sqlite3, +} + +// SAFETY: ConnectionInner owns the underlying pointer, and all its methods are safe to call from +// different threads, just not more than one at a time. +unsafe impl Send for ConnectionInner {} + +impl ConnectionInner { + // we take &str and not &Path because 1. sqlite specifies that the argument should be + // valid utf-8 and 2. there are valid inputs that aren't actually paths + fn open(filename: &str, flags: OpenFlags) -> Result { + let c_filename = CString::from_str(filename)?; + + let mut db: *mut ffi::sqlite3 = std::ptr::null_mut(); + let r = + // SAFETY: we're mutating the pointer db, not dereferencing it, so we're good + unsafe { ffi::sqlite3_open_v2(c_filename.as_ptr(), &mut db, flags.bits(), std::ptr::null()) }; + + if r == ffi::SQLITE_OK { + Ok(Self { db }) + } else if db.is_null() { + Err(Error::from_code(r)) + } else { + let e = Error::from_db(db); + // SAFETY: db came from sqlite3_open_v2 and it's not null so we're good to close it + let r = unsafe { ffi::sqlite3_close(db) }; + debug_assert_eq!(r, ffi::SQLITE_OK); + Err(e) + } + } + + fn get_error(&mut self) -> Error { + Error::from_db(self.db) + } + + fn decode_response(&mut self, code: c_int) -> Result<()> { + if code == ffi::SQLITE_OK { + Ok(()) + } else { + Err(self.get_error()) + } + } + + fn changes(&mut self) -> usize { + // SAFETY: this is only ever called immediately after executing a statement, and we're only + // using the database connection from a single thread at a time, so the value should always + // be good. + (unsafe { ffi::sqlite3_changes(self.db) }) as usize + } + + fn prepare_raw(&mut self, src: &str, flags: PrepFlags) -> Result<(RawStatement, usize)> { + let mut stmt: *mut ffi::sqlite3_stmt = ptr::null_mut(); + let mut tail: *const c_char = ptr::null(); + + // SAFETY: we know self.db hasn't been closed because we only close when we drop, and we + // know &mut c_stmt isn't null because it's pointing to a thing we own. this upholds + // sqlite's requirements + self.decode_response(unsafe { + ffi::sqlite3_prepare_v3( + self.db, + src.as_ptr().cast(), + src.len() as c_int, + flags.bits(), + &mut stmt, + &mut tail, + ) + })?; + + if stmt.is_null() { + Err(Error::EmptyStatement) + } else { + // SAFETY: sqlite guarantees that tail will point into src, so we know its address is >= + let read = unsafe { tail.offset_from_unsigned(src.as_ptr().cast()) }; + Ok((RawStatement::new(stmt), read)) + } + } +} + +impl Drop for ConnectionInner { + fn drop(&mut self) { + // SAFETY: Connection always owns the underlying db, so can close it + let r = unsafe { ffi::sqlite3_close(self.db) }; + debug_assert_eq!(r, ffi::SQLITE_OK); + } +} + +pub struct RawStatement { + ptr: *mut ffi::sqlite3_stmt, +} + +impl Drop for RawStatement { + fn drop(&mut self) { + // SAFETY: RawStatement owns its pointer so this should never double free + unsafe { ffi::sqlite3_finalize(self.ptr) }; + } +} + +impl RawStatement { + fn new(ptr: *mut ffi::sqlite3_stmt) -> Self { + Self { ptr } + } + + /// # Safety + /// + /// the caller must ensure that `conn` is the database connection that was used to prepare the + /// statement. this isn't unsafe in a strict sense, since the outcome of failing to uphold this + /// invariant is just `MISUSE` errors for sqlite but still, don't do it! + pub unsafe fn with_conn<'a>(&'a mut self, conn: &'a Connection) -> BorrowedStatement<'a> { + // copying a type w/ a destructor is bad news, but we're immediately wrapping it to make + // sure the destructor won't be called twice. + let inner = ManuallyDrop::new(Statement { + raw: Self { ptr: self.ptr }, + conn, + }); + BorrowedStatement { + marker: PhantomData, + inner, + } + } +} + +pub struct Statement<'a> { + raw: RawStatement, + conn: &'a Connection, +} + +// lot of little wrapper functions that we know are safe because holding a borrow guarantees +// we're on the same thread as our database connection, and they don't have any safety requirements +// beyond that. +#[allow(clippy::undocumented_unsafe_blocks)] +impl<'a> Statement<'a> { + fn ptr(&self) -> *mut ffi::sqlite3_stmt { + self.raw.ptr + } + + fn step(&mut self) -> c_int { + unsafe { ffi::sqlite3_step(self.ptr()) } + } + + fn column_type(&mut self, i: c_int) -> Result { + let count = unsafe { ffi::sqlite3_column_count(self.ptr()) }; + if i >= count || i < 0 { + Err(Error::InvalidColumn(i)) + } else { + Ok(unsafe { ffi::sqlite3_column_type(self.ptr(), i) }) + } + } + + fn reset_inner(&mut self) -> c_int { + unsafe { ffi::sqlite3_reset(self.ptr()) } + } + + fn reset(&mut self) -> Result<()> { + self.conn.decode_response(self.reset_inner()) + } + + fn bind_i32(&mut self, i: c_int, n: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_int(self.ptr(), i, n) }; + self.conn.decode_response(r) + } + + fn bind_i64(&mut self, i: c_int, n: i64) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_int64(self.ptr(), i, n) }; + self.conn.decode_response(r) + } + + fn bind_double(&mut self, i: c_int, f: f64) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_double(self.ptr(), i, f) }; + self.conn.decode_response(r) + } + + fn bind_null(&mut self, i: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_null(self.ptr(), i) }; + self.conn.decode_response(r) + } + + pub fn query<'q, P: Params>(&'q mut self, params: P) -> Result> { + params.__bind_in(self)?; + Ok(Rows::new(self)) + } + + pub fn execute(&mut self, params: P) -> Result { + params.__bind_in(self)?; + match self.step() { + ffi::SQLITE_DONE => { + self.reset()?; + Ok(self.conn.changes()) + } + ffi::SQLITE_ROW => { + self.reset_inner(); + Err(Error::ExecReturnedRows) + } + _ => Err(self.conn.get_error()), + } + } +} + +pub struct BorrowedStatement<'conn> { + marker: PhantomData<&'conn mut RawStatement>, + // using the raw pointer contained in here as a secret &'a mut RawStatement so we can make Deref + // happy but that means we REALLY DONT WANNA DROP THIS bc we don't *actually* own the + // underlying pointer + inner: ManuallyDrop>, +} + +impl<'conn> Deref for BorrowedStatement<'conn> { + type Target = Statement<'conn>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for BorrowedStatement<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +pub struct Rows<'conn, 'stmt> { + stmt: &'stmt mut Statement<'conn>, + // hate that this is necessary. should be able to get away with making stmt optional but get + // stuck in borrow checker hell + finished: bool, +} + +impl<'a, 'q> Rows<'a, 'q> { + fn new(stmt: &'q mut Statement<'a>) -> Self { + Self { + stmt, + finished: false, + } + } + + pub fn try_next_row<'row>(&'row mut self) -> Result>> { + if self.finished { + return Ok(None); + } + + match self.stmt.step() { + ffi::SQLITE_ROW => Ok(Some(Row { stmt: self.stmt })), + ffi::SQLITE_DONE => { + self.finished = true; + self.stmt.reset()?; + Ok(None) + } + _ => { + self.stmt.reset_inner(); + Err(self.stmt.conn.get_error()) + } + } + } +} + +pub struct Row<'conn, 'row> { + stmt: &'row mut Statement<'conn>, +} + +impl<'row> Row<'_, 'row> { + // should potentially add a "raw" variant of this that lets the caller just accept the sqlite + // dynamic conversions + fn get_value(&mut self, i: c_int) -> Result> { + let ptr = self.stmt.ptr(); + + Ok(match self.stmt.column_type(i)? { + ffi::SQLITE_INTEGER => Value::Int( + // SAFETY: we've verified the type, ptr is valid, etc. + unsafe { ffi::sqlite3_column_int64(ptr, i) }, + ), + ffi::SQLITE_FLOAT => Value::Float( + // SAFETY: same as above + unsafe { ffi::sqlite3_column_double(ptr, i) }, + ), + ffi::SQLITE_TEXT => { + // SAFETY: same as above + let data = unsafe { ffi::sqlite3_column_text(ptr, i) }; + if data.is_null() { + Value::Null + } else { + // SAFETY: going through the criteria for from_raw_parts we have + // 1. we know data is non-null, and sqlite has assured us it's valid for len + // bytes + // 2. utf-8 strings have no alignment requirements so that's fine + // 3. our lifetime restrictions should prevent the statement from being stepped + // again before this reference goes out of scope, so it shouldn't be mutated + Value::Text(str::from_utf8(unsafe { + let len = ffi::sqlite3_column_bytes(ptr, i) as usize; + slice::from_raw_parts(data, len) + })?) + } + } + ffi::SQLITE_BLOB => { + // SAFETY: same as above + let data: *const u8 = unsafe { ffi::sqlite3_column_blob(ptr, i) }.cast(); + if data.is_null() { + Value::Null + } else { + // SAFETY: same as above + Value::Blob(unsafe { + let len = ffi::sqlite3_column_bytes(ptr, i) as usize; + slice::from_raw_parts(data, len) + }) + } + } + ffi::SQLITE_NULL => Value::Null, + _ => unreachable!(), + }) + } + + pub fn get(&mut self, col: c_int) -> GetResult { + let val = self.get_value(col)?; + + T::try_from_sql(val).map_err(GetError::FromSql) + } +} -- cgit 1.4.1