diff options
Diffstat (limited to '')
| -rw-r--r-- | wyrd_sqlite/src/error.rs | 84 | ||||
| -rw-r--r-- | wyrd_sqlite/src/from_sql.rs | 85 | ||||
| -rw-r--r-- | wyrd_sqlite/src/lib.rs | 436 | ||||
| -rw-r--r-- | wyrd_sqlite/src/params.rs | 92 |
4 files changed, 697 insertions, 0 deletions
diff --git a/wyrd_sqlite/src/error.rs b/wyrd_sqlite/src/error.rs new file mode 100644 index 0000000..281862f --- /dev/null +++ b/wyrd_sqlite/src/error.rs @@ -0,0 +1,84 @@ +use std::{ + ffi::{CStr, NulError, c_int}, + fmt::{self, Display, Formatter}, + str::Utf8Error, +}; + +use crate::{FromSql, ffi}; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error(transparent)] + Sqlite(#[from] SqliteError), + #[error("input contained no SQL")] + EmptyStatement, + #[error(transparent)] + Nul(#[from] NulError), + #[error(transparent)] + Utf8(#[from] Utf8Error), + #[error("invalid column index {0}")] + InvalidColumn(c_int), + #[error("execute returned results")] + ExecReturnedRows, + #[error("multiple statements provided")] + MultipleStatements, +} + +#[derive(Debug)] +pub struct SqliteError { + code: c_int, + msg: Option<String>, +} + +impl Error { + pub(crate) fn from_code(code: c_int) -> Self { + SqliteError { code, msg: None }.into() + } + + pub(crate) fn from_db(db: *mut ffi::sqlite3) -> Self { + // SAFETY: sqlite has checks to handle if db is null or dangling, so these shouldn't cause + // ub for any input + let (code, c_msg) = unsafe { (ffi::sqlite3_errcode(db), ffi::sqlite3_errmsg(db)) }; + + let msg = if c_msg.is_null() { + None + } else { + Some( + // SAFETY: as long as c_msg is non-null, sqlite shouldn't be giving us bad strings + unsafe { CStr::from_ptr(c_msg) } + .to_string_lossy() + .to_string(), + ) + }; + + SqliteError { code, msg }.into() + } +} + +fn errstr(code: c_int) -> &'static str { + // SAFETY: `sqlite3_errstr` always returns a valid null-terminated static string + unsafe { CStr::from_ptr(ffi::sqlite3_errstr(code)) } + .to_str() + .expect("sqlite errors should be valid utf8") +} + +impl Display for SqliteError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let msg = self.msg.as_deref().unwrap_or(errstr(self.code)); + write!(f, "{msg} ({})", self.code) + } +} + +impl std::error::Error for SqliteError {} + +pub type Result<T> = std::result::Result<T, Error>; + +#[derive(Debug, thiserror::Error)] +pub enum GetError<T: FromSql> { + #[error(transparent)] + Sqlite(#[from] Error), + #[error(transparent)] + FromSql(T::Error), +} + +pub type GetResult<T> = std::result::Result<T, GetError<T>>; diff --git a/wyrd_sqlite/src/from_sql.rs b/wyrd_sqlite/src/from_sql.rs new file mode 100644 index 0000000..3b149e8 --- /dev/null +++ b/wyrd_sqlite/src/from_sql.rs @@ -0,0 +1,85 @@ +use std::num::{NonZero, TryFromIntError}; + +pub enum Value<'a> { + Int(i64), + Float(f64), + Text(&'a str), + Blob(&'a [u8]), + Null, +} + +// this is an owned trait to get around some problems with blanket implementations +pub trait FromSql: Sized { + type Error: std::error::Error; + + fn try_from_sql(value: Value<'_>) -> Result<Self, Self::Error>; +} + +#[derive(Debug, thiserror::Error)] +#[error("invalid type")] +pub struct InvalidTypeError; + +#[derive(Debug, thiserror::Error)] +pub enum FromSqlIntError { + #[error(transparent)] + InvalidType(#[from] InvalidTypeError), + #[error(transparent)] + TryFromInt(#[from] TryFromIntError), +} + +impl FromSql for i64 { + type Error = FromSqlIntError; + + fn try_from_sql(value: Value<'_>) -> Result<Self, Self::Error> { + if let Value::Int(i) = value { + Ok(i) + } else { + Err(InvalidTypeError.into()) + } + } +} + +impl FromSql for NonZero<i64> { + type Error = FromSqlIntError; + + fn try_from_sql(value: Value) -> Result<Self, Self::Error> { + let i: i64 = FromSql::try_from_sql(value)?; + Ok(i.try_into()?) + } +} + +macro_rules! int_impl { + ($($source:ty),+) => {$( + impl FromSql for $source { + type Error = FromSqlIntError; + + fn try_from_sql(value: Value<'_>) -> Result<Self, Self::Error> { + let i: i64 = FromSql::try_from_sql(value)?; + Ok(i.try_into()?) + } + } + + impl FromSql for NonZero<$source> { + type Error = FromSqlIntError; + + fn try_from_sql(value: Value<'_>) -> Result<Self, Self::Error> { + let i: NonZero<i64> = FromSql::try_from_sql(value)?; + Ok(i.try_into()?) + } + } + )*} +} + +int_impl!(i8, i16, i32, u8, u16, u32, u64); + +impl<'a, T: FromSql> FromSql for Option<T> { + type Error = T::Error; + + fn try_from_sql(value: Value<'_>) -> Result<Self, Self::Error> { + if let Value::Null = value { + Ok(None) + } else { + FromSql::try_from_sql(value).map(Some) + } + } +} diff --git a/wyrd_sqlite/src/lib.rs b/wyrd_sqlite/src/lib.rs new file mode 100644 index 0000000..2ef498a --- /dev/null +++ b/wyrd_sqlite/src/lib.rs @@ -0,0 +1,436 @@ +//! thin wrapper around the bits of sqlite we need. why not just use +//! [rusqlite](https://docs.rs/rusqlite) or similar? i ran into a minor annoyance with their API +//! and decided it would be fun to reinvent the wheel a bit and see how i could do. that said, this +//! implementation owes a great deal to theirs. +#![deny(clippy::undocumented_unsafe_blocks)] +#![deny(clippy::missing_safety_doc)] +#![deny(unsafe_op_in_unsafe_fn)] + +use std::{ + cell::RefCell, + ffi::{CString, c_char, c_int, c_uint}, + marker::PhantomData, + mem::ManuallyDrop, + ops::{Deref, DerefMut}, + ptr, slice, + str::FromStr, +}; + +use libsqlite3_sys as ffi; + +mod error; +mod from_sql; +mod params; + +pub use error::{Error, GetError, GetResult, Result}; +pub use from_sql::{FromSql, Value}; +pub use params::{Param, Params}; + +fn version_number() -> i32 { + // SAFETY: trivial wrapper + unsafe { ffi::sqlite3_libversion_number() } +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug)] + #[repr(C)] + pub struct OpenFlags: c_int { + const READONLY = ffi::SQLITE_OPEN_READONLY; + const READWRITE = ffi::SQLITE_OPEN_READWRITE; + const CREATE = ffi::SQLITE_OPEN_CREATE; + const URI = ffi::SQLITE_OPEN_URI; + const MEMORY = ffi::SQLITE_OPEN_MEMORY; + const NOMUTEX = ffi::SQLITE_OPEN_NOMUTEX; + const NOFOLLOW = ffi::SQLITE_OPEN_NOFOLLOW; + const EXRESCODE = ffi::SQLITE_OPEN_EXRESCODE; + } +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, Default)] + #[repr(C)] + pub struct PrepFlags: c_uint { + const PERSISTENT = ffi::SQLITE_PREPARE_PERSISTENT; + const NO_VTAB = ffi::SQLITE_PREPARE_NO_VTAB; + const DONT_LOG = ffi::SQLITE_PREPARE_DONT_LOG; + } +} + +impl Default for OpenFlags { + fn default() -> Self { + if version_number() > 3_037_000 { + Self::READWRITE | Self::CREATE | Self::NOMUTEX | Self::URI | Self::EXRESCODE + } else { + Self::READWRITE | Self::CREATE | Self::NOMUTEX | Self::URI + } + } +} + +pub struct Connection(RefCell<ConnectionInner>); + +impl Connection { + pub fn open(filename: &str) -> Result<Self> { + Self::open_with_flags(filename, OpenFlags::default()) + } + + pub fn open_with_flags(filename: &str, flags: OpenFlags) -> Result<Self> { + ConnectionInner::open(filename, flags) + .map(RefCell::new) + .map(Self) + } + + pub fn prepare(&self, src: &str) -> Result<(Statement<'_>, usize)> { + self.prepare_with_flags(src, PrepFlags::default()) + } + + pub fn prepare_with_flags( + &self, + src: &str, + flags: PrepFlags, + ) -> Result<(Statement<'_>, usize)> { + self.prepare_raw(src, flags) + .map(|(raw, read)| (Statement { raw, conn: self }, read)) + } + + pub fn prepare_raw(&self, src: &str, flags: PrepFlags) -> Result<(RawStatement, usize)> { + self.0.borrow_mut().prepare_raw(src, flags) + } + + fn get_error(&self) -> Error { + self.0.borrow_mut().get_error() + } + + fn decode_response(&self, code: c_int) -> Result<()> { + self.0.borrow_mut().decode_response(code) + } + + fn changes(&self) -> usize { + self.0.borrow_mut().changes() + } + + pub fn execute<P: Params>(&self, src: &str, params: P) -> Result<usize> { + let (mut stmt, read) = self.prepare(src)?; + if read != src.len() { + Err(Error::MultipleStatements) + } else { + stmt.execute(params) + } + } +} + +struct ConnectionInner { + db: *mut ffi::sqlite3, +} + +// SAFETY: ConnectionInner owns the underlying pointer, and all its methods are safe to call from +// different threads, just not more than one at a time. +unsafe impl Send for ConnectionInner {} + +impl ConnectionInner { + // we take &str and not &Path because 1. sqlite specifies that the argument should be + // valid utf-8 and 2. there are valid inputs that aren't actually paths + fn open(filename: &str, flags: OpenFlags) -> Result<Self> { + let c_filename = CString::from_str(filename)?; + + let mut db: *mut ffi::sqlite3 = std::ptr::null_mut(); + let r = + // SAFETY: we're mutating the pointer db, not dereferencing it, so we're good + unsafe { ffi::sqlite3_open_v2(c_filename.as_ptr(), &mut db, flags.bits(), std::ptr::null()) }; + + if r == ffi::SQLITE_OK { + Ok(Self { db }) + } else if db.is_null() { + Err(Error::from_code(r)) + } else { + let e = Error::from_db(db); + // SAFETY: db came from sqlite3_open_v2 and it's not null so we're good to close it + let r = unsafe { ffi::sqlite3_close(db) }; + debug_assert_eq!(r, ffi::SQLITE_OK); + Err(e) + } + } + + fn get_error(&mut self) -> Error { + Error::from_db(self.db) + } + + fn decode_response(&mut self, code: c_int) -> Result<()> { + if code == ffi::SQLITE_OK { + Ok(()) + } else { + Err(self.get_error()) + } + } + + fn changes(&mut self) -> usize { + // SAFETY: this is only ever called immediately after executing a statement, and we're only + // using the database connection from a single thread at a time, so the value should always + // be good. + (unsafe { ffi::sqlite3_changes(self.db) }) as usize + } + + fn prepare_raw(&mut self, src: &str, flags: PrepFlags) -> Result<(RawStatement, usize)> { + let mut stmt: *mut ffi::sqlite3_stmt = ptr::null_mut(); + let mut tail: *const c_char = ptr::null(); + + // SAFETY: we know self.db hasn't been closed because we only close when we drop, and we + // know &mut c_stmt isn't null because it's pointing to a thing we own. this upholds + // sqlite's requirements + self.decode_response(unsafe { + ffi::sqlite3_prepare_v3( + self.db, + src.as_ptr().cast(), + src.len() as c_int, + flags.bits(), + &mut stmt, + &mut tail, + ) + })?; + + if stmt.is_null() { + Err(Error::EmptyStatement) + } else { + // SAFETY: sqlite guarantees that tail will point into src, so we know its address is >= + let read = unsafe { tail.offset_from_unsigned(src.as_ptr().cast()) }; + Ok((RawStatement::new(stmt), read)) + } + } +} + +impl Drop for ConnectionInner { + fn drop(&mut self) { + // SAFETY: Connection always owns the underlying db, so can close it + let r = unsafe { ffi::sqlite3_close(self.db) }; + debug_assert_eq!(r, ffi::SQLITE_OK); + } +} + +pub struct RawStatement { + ptr: *mut ffi::sqlite3_stmt, +} + +impl Drop for RawStatement { + fn drop(&mut self) { + // SAFETY: RawStatement owns its pointer so this should never double free + unsafe { ffi::sqlite3_finalize(self.ptr) }; + } +} + +impl RawStatement { + fn new(ptr: *mut ffi::sqlite3_stmt) -> Self { + Self { ptr } + } + + /// # Safety + /// + /// the caller must ensure that `conn` is the database connection that was used to prepare the + /// statement. this isn't unsafe in a strict sense, since the outcome of failing to uphold this + /// invariant is just `MISUSE` errors for sqlite but still, don't do it! + pub unsafe fn with_conn<'a>(&'a mut self, conn: &'a Connection) -> BorrowedStatement<'a> { + // copying a type w/ a destructor is bad news, but we're immediately wrapping it to make + // sure the destructor won't be called twice. + let inner = ManuallyDrop::new(Statement { + raw: Self { ptr: self.ptr }, + conn, + }); + BorrowedStatement { + marker: PhantomData, + inner, + } + } +} + +pub struct Statement<'a> { + raw: RawStatement, + conn: &'a Connection, +} + +// lot of little wrapper functions that we know are safe because holding a borrow guarantees +// we're on the same thread as our database connection, and they don't have any safety requirements +// beyond that. +#[allow(clippy::undocumented_unsafe_blocks)] +impl<'a> Statement<'a> { + fn ptr(&self) -> *mut ffi::sqlite3_stmt { + self.raw.ptr + } + + fn step(&mut self) -> c_int { + unsafe { ffi::sqlite3_step(self.ptr()) } + } + + fn column_type(&mut self, i: c_int) -> Result<c_int> { + let count = unsafe { ffi::sqlite3_column_count(self.ptr()) }; + if i >= count || i < 0 { + Err(Error::InvalidColumn(i)) + } else { + Ok(unsafe { ffi::sqlite3_column_type(self.ptr(), i) }) + } + } + + fn reset_inner(&mut self) -> c_int { + unsafe { ffi::sqlite3_reset(self.ptr()) } + } + + fn reset(&mut self) -> Result<()> { + self.conn.decode_response(self.reset_inner()) + } + + fn bind_i32(&mut self, i: c_int, n: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_int(self.ptr(), i, n) }; + self.conn.decode_response(r) + } + + fn bind_i64(&mut self, i: c_int, n: i64) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_int64(self.ptr(), i, n) }; + self.conn.decode_response(r) + } + + fn bind_double(&mut self, i: c_int, f: f64) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_double(self.ptr(), i, f) }; + self.conn.decode_response(r) + } + + fn bind_null(&mut self, i: c_int) -> Result<()> { + let r = unsafe { ffi::sqlite3_bind_null(self.ptr(), i) }; + self.conn.decode_response(r) + } + + pub fn query<'q, P: Params>(&'q mut self, params: P) -> Result<Rows<'a, 'q>> { + params.__bind_in(self)?; + Ok(Rows::new(self)) + } + + pub fn execute<P: Params>(&mut self, params: P) -> Result<usize> { + params.__bind_in(self)?; + match self.step() { + ffi::SQLITE_DONE => { + self.reset()?; + Ok(self.conn.changes()) + } + ffi::SQLITE_ROW => { + self.reset_inner(); + Err(Error::ExecReturnedRows) + } + _ => Err(self.conn.get_error()), + } + } +} + +pub struct BorrowedStatement<'conn> { + marker: PhantomData<&'conn mut RawStatement>, + // using the raw pointer contained in here as a secret &'a mut RawStatement so we can make Deref + // happy but that means we REALLY DONT WANNA DROP THIS bc we don't *actually* own the + // underlying pointer + inner: ManuallyDrop<Statement<'conn>>, +} + +impl<'conn> Deref for BorrowedStatement<'conn> { + type Target = Statement<'conn>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for BorrowedStatement<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +pub struct Rows<'conn, 'stmt> { + stmt: &'stmt mut Statement<'conn>, + // hate that this is necessary. should be able to get away with making stmt optional but get + // stuck in borrow checker hell + finished: bool, +} + +impl<'a, 'q> Rows<'a, 'q> { + fn new(stmt: &'q mut Statement<'a>) -> Self { + Self { + stmt, + finished: false, + } + } + + pub fn try_next_row<'row>(&'row mut self) -> Result<Option<Row<'a, 'row>>> { + if self.finished { + return Ok(None); + } + + match self.stmt.step() { + ffi::SQLITE_ROW => Ok(Some(Row { stmt: self.stmt })), + ffi::SQLITE_DONE => { + self.finished = true; + self.stmt.reset()?; + Ok(None) + } + _ => { + self.stmt.reset_inner(); + Err(self.stmt.conn.get_error()) + } + } + } +} + +pub struct Row<'conn, 'row> { + stmt: &'row mut Statement<'conn>, +} + +impl<'row> Row<'_, 'row> { + // should potentially add a "raw" variant of this that lets the caller just accept the sqlite + // dynamic conversions + fn get_value(&mut self, i: c_int) -> Result<Value<'row>> { + let ptr = self.stmt.ptr(); + + Ok(match self.stmt.column_type(i)? { + ffi::SQLITE_INTEGER => Value::Int( + // SAFETY: we've verified the type, ptr is valid, etc. + unsafe { ffi::sqlite3_column_int64(ptr, i) }, + ), + ffi::SQLITE_FLOAT => Value::Float( + // SAFETY: same as above + unsafe { ffi::sqlite3_column_double(ptr, i) }, + ), + ffi::SQLITE_TEXT => { + // SAFETY: same as above + let data = unsafe { ffi::sqlite3_column_text(ptr, i) }; + if data.is_null() { + Value::Null + } else { + // SAFETY: going through the criteria for from_raw_parts we have + // 1. we know data is non-null, and sqlite has assured us it's valid for len + // bytes + // 2. utf-8 strings have no alignment requirements so that's fine + // 3. our lifetime restrictions should prevent the statement from being stepped + // again before this reference goes out of scope, so it shouldn't be mutated + Value::Text(str::from_utf8(unsafe { + let len = ffi::sqlite3_column_bytes(ptr, i) as usize; + slice::from_raw_parts(data, len) + })?) + } + } + ffi::SQLITE_BLOB => { + // SAFETY: same as above + let data: *const u8 = unsafe { ffi::sqlite3_column_blob(ptr, i) }.cast(); + if data.is_null() { + Value::Null + } else { + // SAFETY: same as above + Value::Blob(unsafe { + let len = ffi::sqlite3_column_bytes(ptr, i) as usize; + slice::from_raw_parts(data, len) + }) + } + } + ffi::SQLITE_NULL => Value::Null, + _ => unreachable!(), + }) + } + + pub fn get<T: FromSql>(&mut self, col: c_int) -> GetResult<T> { + let val = self.get_value(col)?; + + T::try_from_sql(val).map_err(GetError::FromSql) + } +} diff --git a/wyrd_sqlite/src/params.rs b/wyrd_sqlite/src/params.rs new file mode 100644 index 0000000..4561c40 --- /dev/null +++ b/wyrd_sqlite/src/params.rs @@ -0,0 +1,92 @@ +use std::ffi::{c_double, c_int}; + +use variadics_please::all_tuples_enumerated; + +use crate::{Result, Statement, ffi}; + +mod sealed { + pub trait Sealed {} +} + +use sealed::Sealed; + +pub trait Param: Sealed { + #[doc(hidden)] + fn __bind_in(&self, stmt: &mut Statement<'_>, i: c_int) -> Result<()>; +} + +pub trait Params: Sealed { + #[doc(hidden)] + fn __bind_in(&self, stmt: &mut Statement<'_>) -> Result<()>; +} + +impl Sealed for c_int {} +impl Param for c_int { + #[inline] + fn __bind_in(&self, stmt: &mut Statement<'_>, i: c_int) -> Result<()> { + stmt.bind_i32(i, *self) + } +} + +impl Sealed for ffi::sqlite3_int64 {} +impl Param for ffi::sqlite3_int64 { + #[inline] + fn __bind_in(&self, stmt: &mut Statement<'_>, i: c_int) -> Result<()> { + stmt.bind_i64(i, *self) + } +} + +impl Sealed for c_double {} +impl Param for c_double { + #[inline] + fn __bind_in(&self, stmt: &mut Statement<'_>, i: c_int) -> Result<()> { + stmt.bind_double(i, *self) + } +} + +impl<T: Param> Sealed for Option<T> {} +impl<T: Param> Param for Option<T> { + fn __bind_in(&self, stmt: &mut Statement<'_>, i: c_int) -> Result<()> { + match self.as_ref() { + Some(v) => v.__bind_in(stmt, i), + // explicitly binding NULL is required in case a previous call bound a value here + None => stmt.bind_null(i), + } + } +} + +impl Sealed for () {} +impl Param for () { + #[inline] + fn __bind_in(&self, stmt: &mut Statement<'_>, i: c_int) -> Result<()> { + stmt.bind_null(i) + } +} + +impl Params for () { + fn __bind_in(&self, _stmt: &mut Statement<'_>) -> Result<()> { + Ok(()) + } +} + +macro_rules! impl_params { + ($(($i:tt, $T:ident)),*) => { + impl<$($T:Param),*> Sealed for ($($T,)*) {} + impl<$($T:Param),*> Params for ($($T,)*) { + fn __bind_in(&self, stmt: &mut Statement<'_>) -> Result<()> { + $(self.$i.__bind_in(stmt, $i + 1)?;)* + Ok(()) + } + } + + impl<$($T:Param),*> Sealed for ($(($T, c_int),)*) {} + impl<$($T:Param),*> Params for ($(($T, c_int),)*) { + fn __bind_in(&self, stmt: &mut Statement<'_>) -> Result<()> { + $(self.$i.0.__bind_in(stmt, self.$i.1)?;)* + Ok(()) + } + } + } +} + +all_tuples_enumerated!(impl_params, 1, 12, T); |