diff --git a/Cargo.lock b/Cargo.lock index 4c1bf2a..7ecc085 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -324,6 +324,12 @@ dependencies = [ "slab", ] +[[package]] +name = "generativity" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c81fb5260e37854d09d5c87183309fd8c555b75289427884b25660bc87a85e" + [[package]] name = "getrandom" version = "0.4.2" @@ -555,6 +561,7 @@ dependencies = [ "bindgen", "cc", "criterion", + "generativity", "glob", "libc", "pg_query", diff --git a/Cargo.toml b/Cargo.toml index cea3270..551bce9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ edition = "2024" field_offset_assertions = [] [dependencies] +generativity = "1.2.1" libc = "0.2.186" thiserror = "2.0.18" diff --git a/build.rs b/build.rs index fbf4136..ca493f6 100644 --- a/build.rs +++ b/build.rs @@ -48,10 +48,14 @@ fn main() { let mut node_bindings = String::new(); bindgen .clone() - // This is another name for Node + // Exclude Node types that aren't parse nodes and would require + // additional logic to support + .blocklist_type("Const") .blocklist_type("Expr") - // This is yet another name for Node + .blocklist_type("JsonTablePath") .blocklist_type("JsonTablePlan") + .blocklist_type("RelabelType") + .blocklist_type("Var") // Yes, we want doc comments .clang_arg("-fparse-all-comments") .derive_debug(false) @@ -60,10 +64,38 @@ fn main() { // SAFETY: YOLO .write(Box::new(unsafe { node_bindings.as_mut_vec() })) .unwrap(); + let node_structs = generate_node_structs(&node_bindings, &out_dir.join("nodes_raw.rs")).unwrap(); generate_node_enum(&node_structs, &out_dir.join("node_enum_raw.rs")).unwrap(); + let mut makefunc_bindings = String::new(); + bindgen + .clone() + .allowlist_file( + c_dir + .join("src/postgres/include/nodes/makefuncs.h") + .to_str() + .unwrap(), + ) + .allowlist_file( + c_dir + .join("src/postgres/include/nodes/value.h") + .to_str() + .unwrap(), + ) + .blocklist_item("makeDefElemExtended") // This type has multiple makefuncs + .blocklist_item("makeColumnDef") // Has more logic than we want + .blocklist_item("makeTypeNameFromOid") // Parser doesn't know OIDs + .blocklist_item("makeTypeName") // We map to the list form, not unqualified + .blocklist_item("makeSimpleA_Expr") // We map to the list form, not unqualified + .generate() + .unwrap() + // SAFETY: YOLO + .write(Box::new(unsafe { makefunc_bindings.as_mut_vec() })) + .unwrap(); + let make_funcs = generate_make_funcs(&makefunc_bindings, &node_structs, &out_dir).unwrap(); + let mut bindgen = bindgen .allowlist_item("Node") .allowlist_item("MemoryContext") @@ -86,11 +118,16 @@ fn main() { ) .allowlist_item("StringInfo") .allowlist_item("wrapped_raw_deparse") + .allowlist_item("wrapped_pnstrdup") + .allowlist_item("list_copy") .wrap_static_fns(true) .wrap_static_fns_path(out_dir.join("wrap_static_fns")); for struct_name in &node_structs { bindgen = bindgen.blocklist_item(struct_name.name.to_string()); } + for make_func_name in &make_funcs { + bindgen = bindgen.allowlist_item(make_func_name); + } bindgen .generate() .unwrap() @@ -146,7 +183,7 @@ impl NodeField { if let NodeFieldType::Primitive(_) = &self.ty { parse_quote!(pub) } else { - parse_quote!(pub(crate)) + syn::Visibility::Inherited } } @@ -155,22 +192,23 @@ impl NodeField { let fattrs = &self.attrs; let fname = &self.name; + let ret = self.ty(&parse_quote!('_)); match &self.ty { Private(_) | Primitive(_) => None, Node => Some(parse_quote! { #(#fattrs)* #[inline] - pub fn #fname(&self) -> crate::Node<'_> { + pub fn #fname(&self) -> #ret { // SAFETY: The lifetime is not longer than self unsafe { crate::Node::from_ptr(self.#fname) } } }), - CastNode(ty) => Some(parse_quote! { + CastNode(_) => Some(parse_quote! { #(#fattrs)* #[inline] - pub fn #fname(&self) -> Option<&#ty> { + pub fn #fname(&self) -> #ret { // SAFETY: Pointer will always be valid or NULL unsafe { self.#fname.as_ref() } } @@ -179,17 +217,17 @@ impl NodeField { List => Some(parse_quote! { #(#fattrs)* #[inline] - pub fn #fname(&self) -> &crate::list::NodeList { + pub fn #fname(&self) -> #ret { // SAFETY: The lifetime is not longer than self unsafe { crate::Node::from_ptr(self.#fname.cast()) } .expect_node_list() } }), - CastList(ty) => Some(parse_quote! { + CastList(_) => Some(parse_quote! { #(#fattrs)* #[inline] - pub fn #fname(&self) -> &crate::list::CastNodeList<&#ty> { + pub fn #fname(&self) -> #ret { // SAFETY: The lifetime is not longer than self unsafe { crate::Node::from_ptr(self.#fname.cast()) } .expect_node_list() @@ -200,7 +238,7 @@ impl NodeField { CString => Some(parse_quote! { #(#fattrs)* #[inline] - pub fn #fname(&self) -> Option<&str> { + pub fn #fname(&self) -> #ret { if self.#fname.is_null() { None } else { @@ -217,7 +255,7 @@ impl NodeField { ConstVal => Some(parse_quote! { #(#fattrs)* #[inline] - pub fn #fname(&self) -> Option> { + pub fn #fname(&self) -> #ret { if self.isnull { None } else { @@ -243,8 +281,34 @@ impl NodeField { parse_quote!(#debug_expr.field(stringify!(#fname), #value_expr)) } - fn ty(&self) -> syn::Type { - self.ty.ty() + fn raw_ty(&self) -> syn::Type { + self.ty.raw_ty() + } + + fn ty(&self, lifetime: &syn::Lifetime) -> syn::Type { + self.ty.ty(lifetime) + } + + fn constructor_ty(&self, lifetime: &syn::Lifetime) -> syn::Type { + self.ty.constructor_ty(lifetime) + } + + fn as_raw_expr(&self) -> syn::Expr { + use NodeFieldType::*; + + let fname = &self.name; + match self.ty { + Private(_) | Primitive(_) => parse_quote!(#fname), + Node | CastNode(_) | List | CastList(_) => parse_quote!(#fname.into_ptr().cast()), + CString => parse_quote! { + #fname + .map(|s| raw::wrapped_pnstrdup(s.as_ptr().cast(), s.len())) + .unwrap_or(ptr::null_mut()) + }, + ConstVal => parse_quote!(compile_error!( + "PG has no functions that take ValUnion by value" + )), + } } } @@ -260,7 +324,7 @@ enum NodeFieldType { } impl NodeFieldType { - fn ty(&self) -> syn::Type { + fn raw_ty(&self) -> syn::Type { match self { Self::Private(t) | Self::Primitive(t) => t.clone(), Self::Node => parse_quote!(*mut Node), @@ -270,6 +334,33 @@ impl NodeFieldType { Self::ConstVal => parse_quote!(ValUnion), } } + + fn ty(&self, lifetime: &syn::Lifetime) -> syn::Type { + match self { + Self::Private(t) | Self::Primitive(t) => t.clone(), + Self::Node => parse_quote!(crate::Node<#lifetime>), + Self::CastNode(t) => parse_quote!(Option<&#lifetime crate::nodes::#t>), + Self::List => parse_quote!(&#lifetime crate::list::NodeList), + Self::CastList(t) => parse_quote!(&#lifetime crate::list::CastNodeList<&#lifetime #t>), + Self::CString => parse_quote!(Option<&#lifetime str>), + Self::ConstVal => parse_quote!(Option>), + } + } + + fn constructor_ty(&self, lifetime: &syn::Lifetime) -> syn::Type { + match self { + Self::Private(t) | Self::Primitive(t) => t.clone(), + Self::Node => parse_quote!(Unique<#lifetime, raw::Node>), + Self::CastNode(t) => parse_quote!(Unique<#lifetime, crate::nodes::#t>), + Self::List => parse_quote!(Unique<#lifetime, crate::list::NodeList>), + Self::CastList(t) => { + parse_quote!(Unique<#lifetime, crate::list::CastNodeList<&#lifetime #t>>) + } + // Strings get copied in constructors so we can ignore the input LT + Self::CString => parse_quote!(Option<&str>), + Self::ConstVal => parse_quote!(Option>), + } + } } /// Generates the structs for each node and writes them to the given path. @@ -380,7 +471,7 @@ fn generate_node_structs( let fattrs = s.fields.iter().map(|f| &f.attrs); let fvis = s.fields.iter().map(|f| f.vis()); let fnames = s.fields.iter().map(|f| &f.name); - let ftys = s.fields.iter().map(|f| f.ty()); + let ftys = s.fields.iter().map(|f| f.raw_ty()); out_file.items.push(parse_quote! { #(#sattrs)* pub struct #sname { @@ -464,7 +555,7 @@ fn generate_node_enum( /// SAFETY: The caller is responsible for ensuring the provided /// lifetime does not outlast the memory context this Node was /// allocated in - pub(crate) unsafe fn from_ptr(ptr: *mut raw::Node) -> Self { + pub unsafe fn from_ptr(ptr: *mut raw::Node) -> Self { // SAFETY: PG will never return an invalid Node other than NULL // and the caller is ensuring a valid lifetime unsafe { ptr.as_ref() }.map(|p| { @@ -496,6 +587,132 @@ fn generate_node_enum( Ok(()) } +/// Returns the function names needed in `raw` +fn generate_make_funcs( + bindings: &str, + node_structs: &[NodeStruct], + out_dir: &Path, +) -> Result, Box> { + let file = syn::parse_file(bindings)?; + let mut out_file = syn::File { + shebang: None, + attrs: Vec::new(), + items: Vec::new(), + }; + + let makefuncs = file + .items + .into_iter() + .flat_map(|i| match i { + syn::Item::ForeignMod(f) => f.items, + _ => Vec::new(), + }) + .filter_map(|i| match i { + syn::ForeignItem::Fn(f) => Some(f), + _ => None, + }) + .filter_map(|f| { + if f.sig.ident.to_string().starts_with("make") + && let syn::ReturnType::Type(_, t) = &f.sig.output + && let syn::Type::Ptr(t) = &**t + && let syn::Type::Path(syn::TypePath { path, .. }) = &*t.elem + && let Some(s) = node_structs + .iter() + .find(|s| path.get_ident() == Some(&s.name)) + { + Some((s, f)) + } else { + None + } + }) + .collect::>(); + + for (node, makefunc) in &makefuncs { + let lt = parse_quote!('a); + let node_name = &node.name; + let func_name = syn::Ident::new(&format!("make_{}", node_name), makefunc.sig.ident.span()); + let raw_func_name = &makefunc.sig.ident; + + let arg_fields = makefunc + .sig + .inputs + .iter() + .filter_map(|arg| match arg { + syn::FnArg::Typed(pat_type) => Some(pat_type), + _ => None, + }) + .filter_map(|arg| { + /// The arity of the constructor functions sometimes varies + /// wildly from the number of fields present on the struct. + /// Because of that, we get the field an argument maps to by + /// name instead of index. But in a handful of cases, those + /// names don't match up, so we have a hard coded list of + /// corrections + static MISMATCHED_FIELD_NAMES: &[((&str, &str), &str)] = &[ + (("BitString", "str_"), "bsval"), + (("Boolean", "val"), "boolval"), + (("DefElem", "name"), "defname"), + (("Float", "numericStr"), "fval"), + (("FuncCall", "name"), "funcname"), + (("FuncExpr", "fformat"), "funcformat"), + (("FuncExpr", "rettype"), "funcresulttype"), + (("Integer", "i"), "ival"), + (("JsonTablePath", "pathname"), "name"), + (("JsonTablePath", "pathvalue"), "value"), + (("JsonTablePathSpec", "string_location"), "location"), + (("String", "str_"), "sval"), + ]; + + let syn::Pat::Ident(syn::PatIdent { ident: arg, .. }) = &*arg.pat else { + return None; + }; + let arg = MISMATCHED_FIELD_NAMES + .iter() + .find_map(|((sname, argname), fname)| { + (node_name == sname && *arg == argname).then(|| (*fname).to_owned()) + }) + .unwrap_or_else(|| arg.to_string()); + + node.fields.iter().find(|f| f.name == arg) + }) + .collect::>(); + + let fargs = arg_fields.iter().map(|field| -> syn::FnArg { + let fname = &field.name; + let fty = field.constructor_ty(<); + parse_quote!(#fname: #fty) + }); + let farg_exprs = arg_fields.iter().map(|field| field.as_raw_expr()); + + out_file.items.push(parse_quote! { + // FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail + // test for invariant lifetimes without making this pub + #[doc(hidden)] + #[allow(non_snake_case)] + pub fn #func_name<#lt>(mem: MemoryToken<#lt>, #(#fargs,)*) -> Unique<#lt, crate::nodes::#node_name> { + // SAFETY: The given closure never panics. The function raw + // functions we call are only allocating and assigning fields. + // They have no error conditions, so we can never longjmp + // over Rust frames. We have explicitly taken a mut reference + // to MemoryContext to ensure the lifetime is invariant + let ptr = unsafe { mem.mem.within(|| { + &mut *raw::#raw_func_name(#(#farg_exprs),*) + }) }; + Unique(Some(ptr), mem.id) + } + }) + } + + std::fs::write( + out_dir.join("make_funcs_raw.rs"), + prettyplease::unparse(&out_file), + )?; + Ok(makefuncs + .into_iter() + .map(|(_, f)| f.sig.ident.to_string()) + .collect()) +} + fn is_flexible_array_ty(ty: &syn::Type) -> bool { matches!( ty, diff --git a/src/const_val.rs b/src/const_val.rs index 91448ae..36c1fdd 100644 --- a/src/const_val.rs +++ b/src/const_val.rs @@ -96,26 +96,20 @@ impl<'a> ConstValue<'a> { #[cfg(test)] mod tests { use super::*; - use crate::nodes; + use crate::make::{make_Boolean, make_Float, make_Integer, memory_token}; + use crate::mem::MemoryContext; #[test] fn test_integer_value() { - let smallint = nodes::Integer { - type_: NodeTag_T_Integer, - ival: 1, - }; - let bigint = nodes::Float { - type_: NodeTag_T_Float, - fval: c"1234567890".as_ptr().cast_mut(), - }; - let boolval = nodes::Boolean { - type_: NodeTag_T_Boolean, - boolval: true, - }; + let mem = MemoryContext::new(c"test_integer_value"); + memory_token!(mem); + let smallint = make_Integer(mem, 1); + let bigint = make_Float(mem, Some("1234567890")); + let boolval = make_Boolean(mem, true); - let smallunion = unsafe { *(&raw const smallint).cast() }; - let bigunion = unsafe { *(&raw const bigint).cast() }; - let boolunion = unsafe { *(&raw const boolval).cast() }; + let smallunion = unsafe { *(smallint.into_ptr()).cast() }; + let bigunion = unsafe { *(bigint.into_ptr()).cast() }; + let boolunion = unsafe { *(boolval.into_ptr()).cast() }; let smallval = ConstValue::from_raw(&smallunion); let bigval = ConstValue::from_raw(&bigunion); diff --git a/src/lib.rs b/src/lib.rs index 055898e..868b4fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,24 @@ #![cfg_attr(feature = "field_offset_assertions", feature(offset_of_enum))] use std::{ffi, fmt, ptr}; +#[doc(hidden)] +pub use generativity::make_guard; + pub mod const_val; mod deparse; pub mod error; pub mod list; -mod mem; +// FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail +// test for invariant lifetimes without making this pub +#[doc(hidden)] +pub mod make; +// FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail +// test for invariant lifetimes without making this pub +#[doc(hidden)] +pub mod mem; pub mod node_enum; pub mod nodes; mod pg_error; -#[allow(warnings)] pub mod raw; pub mod walk; diff --git a/src/make.rs b/src/make.rs new file mode 100644 index 0000000..1b45434 --- /dev/null +++ b/src/make.rs @@ -0,0 +1,99 @@ +use crate::mem::MemoryContext; +use crate::raw::{self, *}; +use generativity::Id; +use std::ptr; + +include!(concat!(env!("OUT_DIR"), "/make_funcs_raw.rs")); + +/// FIXME(sage): These tests don't assert that the failures are lifetime +/// related +/// +/// ```compile_fail +/// use pg_raw_parse::mem::MemoryContext; +/// use pg_raw_parse::make::{memory_token, make_String, make_List}; +/// +/// let mem1 = MemoryContext::new(c"mem1"); +/// let mem2 = MemoryContext::new(c"mem2"); +/// memory_token!(mem1); +/// memory_token!(mem2); +/// let node = make_String(mem2, Some("hi")); +/// let _list = make_List(mem1, &[node]); // Fails, node is on mem2 +/// ``` +/// +/// ``` +/// use pg_raw_parse::mem::MemoryContext; +/// use pg_raw_parse::make::{memory_token, make_String, make_List}; +/// +/// let mem1 = MemoryContext::new(c"mem1"); +/// let mem2 = MemoryContext::new(c"mem2"); +/// memory_token!(mem1); +/// memory_token!(mem2); +/// let node = make_String(mem1, Some("hi")); +/// let _list = make_List(mem1, &[node]); // Is fine, both nodes are on mem1 +/// ``` +// FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail +// test for invariant lifetimes without making this pub +#[doc(hidden)] +#[derive(Clone, Copy)] +pub struct MemoryToken<'a> { + #[doc(hidden)] + pub mem: &'a MemoryContext, + #[doc(hidden)] + pub id: Id<'a>, +} + +#[macro_export] +macro_rules! memory_token { + ($mem:ident) => { + $crate::make_guard!(a); + let $mem = $crate::make::MemoryToken { + mem: &$mem, + id: a.into(), + }; + }; +} + +// FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail +// test for invariant lifetimes without making this pub +#[doc(hidden)] +pub use memory_token; + +/// A uniquely owned pointer to a node. This is effectively `Box`, but +/// constrained to the lifetime of its memory context. +#[repr(C)] +pub struct Unique<'a, T>(Option<&'a mut T>, Id<'a>); + +impl<'a, T> Unique<'a, T> { + /// Consume this to get the inner raw node pointer, erasing its lifetime. + /// The returned pointer should either be stored along side the memory + /// context, or assigned to the field of a node within the same memory + /// context. + pub(crate) fn into_ptr(self) -> *mut raw::Node { + self.0.map(ptr::from_mut).unwrap_or(ptr::null_mut()).cast() + } +} + +// FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail +// test for invariant lifetimes without making this pub +#[doc(hidden)] +#[allow(non_snake_case)] +pub fn make_List<'a, T>( + mem: MemoryToken<'a>, + elems: &[Unique<'a, T>], +) -> Unique<'a, crate::list::NodeList> { + if elems.is_empty() { + Unique(None, mem.id) + } else { + let list_to_copy = raw::List { + type_: raw::NodeTag_T_List, + length: elems.len() as _, + max_length: elems.len() as _, + elements: elems.as_ptr().cast_mut().cast(), + initial_elements: raw::__IncompleteArrayField::new(), + }; + // SAFETY: The given closure never panics, we're passing valid pointers + let ptr = unsafe { mem.mem.within(|| raw::list_copy(&raw const list_to_copy)) }; + // SAFETY: The returned pointer is always a palloc'd list pointer + Unique(Some(unsafe { &mut *ptr.cast() }), mem.id) + } +} diff --git a/src/mem.rs b/src/mem.rs index d4f8b51..08e435a 100644 --- a/src/mem.rs +++ b/src/mem.rs @@ -1,10 +1,16 @@ use crate::raw; use std::ffi::CStr; -pub(crate) struct MemoryContext(raw::MemoryContext); +// FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail +// test for invariant lifetimes without making this pub +#[doc(hidden)] +pub struct MemoryContext(raw::MemoryContext); impl MemoryContext { - pub(crate) fn new(name: &'static CStr) -> Self { + // FIXME(sage): Change to pub(crate) when we have a way to write a compile-fail + // test for invariant lifetimes without making this pub + #[doc(hidden)] + pub fn new(name: &'static CStr) -> Self { // SAFETY: No documented invariants unsafe { raw::pg_query_init(); diff --git a/src/node_enum.rs b/src/node_enum.rs index 9ad42bf..05e6609 100644 --- a/src/node_enum.rs +++ b/src/node_enum.rs @@ -34,19 +34,14 @@ impl<'a> Node<'a> { #[test] fn test_node_as_list() { - let int = nodes::Integer { - type_: raw::NodeTag_T_Integer, - ival: 1, - }; - let mut ptr_to_int = &raw const int; - let mut list = raw::List { - type_: raw::NodeTag_T_List, - length: 1, - max_length: 1, - elements: &raw mut ptr_to_int as *mut raw::ListCell, - initial_elements: raw::__IncompleteArrayField::new(), - }; - let node = unsafe { Node::from_ptr(&raw mut list as _) }; + use crate::make::*; + use crate::mem::MemoryContext; + + let mem = MemoryContext::new(c"test_node_as_list"); + memory_token!(mem); + let int = make_Integer(mem, 1); + let list = make_List(mem, &[int]); + let node = unsafe { Node::from_ptr(list.into_ptr()) }; let actual = node.expect_node_list().into_iter().collect::>(); std::assert_matches!(actual[..], [Node::Integer(nodes::Integer { ival: 1, .. })]); } diff --git a/src/raw.rs b/src/raw.rs index c32d5ee..853833d 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -1,6 +1,7 @@ //! Functions in this module should never be called unless they have been //! manually wrapped *IN C* with `PG_TRY()` and `PG_CATCH()`. PG errors use //! `longjmp`, and jumping over any Rust frames is undefined behavior. +#![allow(warnings)] use crate::nodes::*; @@ -51,11 +52,21 @@ fn test_raw_node_bindings_arent_generated() { .collect::>(); node_structs.sort(); - // We need the raw binding to Node for tag checking, List and - // MemoryContextData are both their own thing, Expr is just an alias for - // Node + // These are the nodes that we either have special handling for, or have + // explicitly blocklisted because they aren't parse nodes and handling + // them would require extra code assert_eq!( node_structs, - &["Expr", "JsonTablePlan", "List", "MemoryContextData", "Node"] + &[ + "Const", // A_Const is the parsed version + "Expr", // Abstract type + "JsonTablePath", // JsonTablePathSpec is the parsed version + "JsonTablePlan", // JsonTablePlanSpec is the parsed version + "List", // list::NodeList + "MemoryContextData", // mem::MemoryContext + "Node", // node_enum::Node + "RelabelType", // Implicit coercion, never parsed + "Var", // Used during optimization, not parsing + ], ); } diff --git a/wrapper.h b/wrapper.h index 9f7f095..ce94dbf 100644 --- a/wrapper.h +++ b/wrapper.h @@ -3,6 +3,7 @@ #include "src/pg_query_internal.h" #include "nodes/parsenodes.h" #include "nodes/nodeFuncs.h" +#include "nodes/makefuncs.h" #include "utils/palloc.h" #include "utils/memutils.h" #include "copy_pg_error.h" @@ -33,3 +34,19 @@ static inline StringInfo wrapped_raw_deparse(RawStmt *stmt, ErrorData **error) { PG_END_TRY(); return str; } + +// FIXME(sage): libpg_query doesn't compile pnstrdup, which we want +static inline +char * +wrapped_pnstrdup(const char *in, Size len) +{ + char *out; + + len = strnlen(in, len); + + out = palloc(len + 1); + memcpy(out, in, len); + out[len] = '\0'; + + return out; +}