Compare commits

...

3 Commits

Author SHA1 Message Date
34726d4f11 Add api_fn macro 2025-10-20 18:11:09 -04:00
c889480913 Add proc_macro2 2025-10-20 18:08:35 -04:00
3466212de7 Move db_type to separate module 2025-10-20 16:11:07 -04:00
5 changed files with 389 additions and 121 deletions

5
Cargo.lock generated
View File

@@ -6,15 +6,16 @@ version = 4
name = "libretunes_macro" name = "libretunes_macro"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"proc-macro2",
"quote", "quote",
"syn", "syn",
] ]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.95" version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]

View File

@@ -7,5 +7,6 @@ edition = "2024"
proc-macro = true proc-macro = true
[dependencies] [dependencies]
proc-macro2 = "1.0.101"
quote = "1.0.40" quote = "1.0.40"
syn = { version = "2.0.101", features = ["full"] } syn = { version = "2.0.101", features = ["full"] }

253
src/api_fn.rs Normal file
View File

@@ -0,0 +1,253 @@
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
*,
punctuated::Punctuated,
};
/// Options parsed from the arguments of the `#[api_fn(...)]` attribute.
struct Flags {
/// `upload` marker: endpoint accepts multipart form-data input.
upload: bool,
/// `admin` marker: generated wrapper rejects non-admin users.
admin: bool,
/// `login` marker: generated wrapper rejects users who are not logged in.
login: bool,
/// Required `endpoint = "..."` value; also interpolated into generated error messages.
endpoint: String,
}
/// What the generated outer function must prepare before calling the inner
/// function, derived from the inner function's parameter names
/// (`db_conn`, `user`, `state`) and the attribute flags.
struct FnArgFlags {
/// Inner fn takes a `db_conn` parameter; wrapper must fetch a DB connection.
need_db_conn: bool,
/// Inner fn takes a `user` parameter, or `admin`/`login` checks require one.
need_user: bool,
/// Declared type of the `user` parameter, when the inner fn names one.
user_type: Option<Box<Type>>,
/// Inner fn takes a `state` parameter (also implied by `db_conn`).
need_state: bool,
}
/// Attribute-macro entry point: wraps the annotated function in a generated
/// `#[server]` endpoint that performs state/db/auth setup before delegating
/// to the original body.
pub fn api_fn_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
    let mut inner_fn = parse_macro_input!(item as ItemFn);

    let flags = match parse_attributes(&args) {
        Ok(f) => f,
        Err(err) => return err.to_compile_error().into(),
    };

    // The public name stays on the generated wrapper; the original body is
    // renamed to `<name>_inner` and only compiled on the server (ssr).
    let signature = &inner_fn.sig;
    let public_name = &signature.ident;
    let hidden_name = format_ident!("{}_inner", public_name);

    let (wrapper_inputs, arg_flags, forwarded_args) =
        match process_function_inputs(signature, &flags) {
            Ok(parts) => parts,
            Err(err) => return err.to_compile_error().into(),
        };

    // Statements the wrapper runs before delegating (state/db/auth setup).
    let setup = generate_precompute_code(&flags, &arg_flags);

    // Wrapper keeps the original signature, minus the injected parameters.
    let mut wrapper_sig = signature.clone();
    wrapper_sig.inputs = wrapper_inputs;

    inner_fn.sig.ident = hidden_name.clone();

    let server_attr = build_server_macro(&flags, &args);

    quote! {
        #[cfg(feature = "ssr")]
        #inner_fn
        #server_attr
        pub #wrapper_sig {
            #(#setup)*
            #hidden_name(#forwarded_args).await
        }
    }.into()
}
/// Parse attributes to `api_fn` into the `Flags` struct
fn parse_attributes(attr_args: &Punctuated<Meta, Token![,]>) -> Result<Flags> {
let mut endpoint = None;
let mut upload_flag = false;
let mut admin_flag = false;
let mut login_flag = false;
for arg in attr_args {
match arg {
Meta::NameValue(name_value) => {
if name_value.path.is_ident("endpoint") {
match &name_value.value {
Expr::Lit(expr_lit) => {
if let Lit::Str(lit_str) = &expr_lit.lit {
endpoint = Some(lit_str.value());
} else {
return Err(Error::new_spanned(&name_value.value, "endpoint must be a string literal"));
}
}
_ => return Err(Error::new_spanned(&name_value.value, "Invalid value for endpoint")),
}
}
}
Meta::Path(path) => {
if path.is_ident("upload") {
upload_flag = true;
} else if path.is_ident("admin") {
admin_flag = true;
} else if path.is_ident("login") {
login_flag = true;
}
}
_ => {}
}
}
let endpoint = endpoint.ok_or_else(|| Error::new_spanned(attr_args, "Must specify API endpoint"))?;
Ok(Flags {
upload: upload_flag,
admin: admin_flag,
login: login_flag,
endpoint
})
}
/// Split the inner function's parameters into:
/// * the parameter list for the generated wrapper (injected params removed),
/// * `FnArgFlags` describing which values the wrapper must prepare, and
/// * the comma-separated argument list used to call the inner function.
///
/// Parameters named `db_conn`, `user`, or `state` are not exposed on the
/// wrapper; they are produced by the generated setup code instead.
fn process_function_inputs(sig: &Signature, flags: &Flags) -> Result<(Punctuated<FnArg, Token![,]>, FnArgFlags, TokenStream2)> {
    let mut outer_inputs = Punctuated::new();
    let mut call_args = vec![];

    let mut need_db_conn = false;
    // Auth checks need the user even when the inner fn doesn't take one.
    let mut need_user = flags.admin || flags.login;
    let mut need_state = false;
    let mut user_type = None;

    for arg in &sig.inputs {
        match arg {
            FnArg::Typed(pat_type) => {
                if let Pat::Ident(PatIdent { ident, .. }) = &*pat_type.pat {
                    let name = ident.to_string();
                    if name == "db_conn" {
                        // The DB connection is fetched from the backend state.
                        need_db_conn = true;
                        need_state = true;
                    } else if name == "user" {
                        need_user = true;
                        user_type = Some(pat_type.ty.clone());
                    } else if name == "state" {
                        need_state = true;
                    } else {
                        // Ordinary parameter: keep it on the wrapper.
                        outer_inputs.push(arg.clone());
                    }
                    call_args.push(quote! { #ident });
                } else {
                    // Non-ident pattern (tuple/struct destructuring): expose it
                    // on the wrapper and forward only the pattern tokens.
                    // Quoting the whole `FnArg` here would emit `pat: Type` —
                    // type ascription, which is not valid call-argument syntax
                    // and would make the expansion fail to compile.
                    let pat = &pat_type.pat;
                    outer_inputs.push(arg.clone());
                    call_args.push(quote! { #pat });
                }
            }
            FnArg::Receiver(_) => {
                outer_inputs.push(arg.clone());
                call_args.push(quote! { self });
            }
        }
    }

    let fn_arg_flags = FnArgFlags {
        need_db_conn,
        need_user,
        user_type,
        need_state,
    };

    Ok((outer_inputs, fn_arg_flags, quote! { #(#call_args),* }))
}
/// Build the statements the generated wrapper runs before calling the inner
/// function: fetching backend state, acquiring a DB connection, resolving the
/// current user, and enforcing the `admin` / `login` requirements.
fn generate_precompute_code(flags: &Flags, fn_args: &FnArgFlags) -> Vec<proc_macro2::TokenStream> {
    let mut pre_compute = Vec::new();

    if fn_args.need_state {
        let state_err_msg = format!("Failed to get backend state for endpoint `{}`", flags.endpoint);
        pre_compute.push(quote! {
            let state_result = crate::util::backend_state::BackendState::get().await;
            let state = crate::util::error::Contextualize::context(state_result, #state_err_msg)?;
        });
    }

    if fn_args.need_db_conn {
        let db_conn_err_msg = format!("Failed to get database connection for endpoint `{}`", flags.endpoint);
        // `db_conn_result` is consumed immediately and never mutated, so it
        // must not be `mut`: a stray `mut` here triggers an `unused_mut`
        // warning in every expansion of the macro.
        pre_compute.push(quote! {
            let db_conn_result = state.get_db_conn();
            let mut db_conn = crate::util::error::Contextualize::context(db_conn_result, #db_conn_err_msg)?;
            let db_conn = &mut db_conn;
        });
    }

    // Fall back to an optional user when the inner fn doesn't name a type.
    let user_type = if let Some(ty) = &fn_args.user_type {
        quote! { #ty }
    } else {
        quote! { Option<crate::models::backend::User> }
    };

    if fn_args.need_user {
        pre_compute.push(quote! {
            let user = <#user_type as crate::api::auth::LoggedInUser>::get().await?;
        });
    }

    if flags.admin {
        pre_compute.push(quote! {
            let is_admin = <#user_type as crate::api::auth::LoggedInUser>::is_admin(&user);
            if !is_admin {
                return Err(crate::util::error::BackendError::AuthError(
                    crate::util::error::AuthError::AdminRequired));
            }
        });
    }

    if flags.login {
        pre_compute.push(quote! {
            let is_logged_in = <#user_type as crate::api::auth::LoggedInUser>::is_logged_in(&user);
            if !is_logged_in {
                return Err(crate::util::error::BackendError::AuthError(
                    crate::util::error::AuthError::Unauthorized));
            }
        });
    }

    pre_compute
}
/// Construct the `#[server(...)]` attribute placed on the generated wrapper.
///
/// The macro-private markers (`upload`, `admin`, `login`) are stripped; every
/// remaining argument is forwarded to `#[server]` verbatim. Upload endpoints
/// use multipart form-data input; all others use the crate's custom client.
fn build_server_macro(flags: &Flags, attr_args: &Punctuated<Meta, Token![,]>) -> TokenStream2 {
    // Only bare paths can be our marker flags; everything else passes through.
    let is_marker = |meta: &Meta| matches!(
        meta,
        Meta::Path(p) if p.is_ident("upload") || p.is_ident("admin") || p.is_ident("login")
    );

    let server_args: Punctuated<Meta, Token![,]> = attr_args
        .iter()
        .filter(|meta| !is_marker(meta))
        .cloned()
        .collect();

    if flags.upload {
        quote! {
            #[server(input = server_fn::codec::MultipartFormData, #server_args)]
        }
    } else {
        quote! {
            #[server(client = crate::util::serverfn_client::Client, #server_args)]
        }
    }
}

123
src/db_type.rs Normal file
View File

@@ -0,0 +1,123 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Path, parse_macro_input};
/// Implementation of `#[db_type(table_path)]`.
///
/// Expands a named-field struct into two structs:
/// * the original, deriving Diesel's `Queryable`/`Selectable` (plus
///   `Identifiable` when an `id` field exists) under the `ssr` feature, and
/// * a `New<Name>` insertable variant with every `#[omit_new]` field removed.
///
/// The `#[omit_new]` helper attributes are stripped from both emitted structs.
pub fn db_type_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(item as DeriveInput);

    // Borrow the struct data instead of cloning the whole `Data` value.
    let Data::Struct(original_struct) = &derive_input.data else {
        return syn::Error::new_spanned(
            &derive_input.ident,
            "#[db_type] can only be applied to structs",
        )
        .to_compile_error()
        .into();
    };

    let Fields::Named(fields) = &original_struct.fields else {
        return syn::Error::new_spanned(
            &derive_input.ident,
            "#[db_type] can only be applied to structs with named fields",
        )
        .to_compile_error()
        .into();
    };

    // `Identifiable` only makes sense when the table has an `id` column.
    let has_id_field = fields
        .named
        .iter()
        .any(|field| field.ident.as_ref().is_some_and(|ident| ident == "id"));

    // The `New<Name>` struct drops every field tagged `#[omit_new]`.
    let new_fields = Fields::Named(FieldsNamed {
        brace_token: fields.brace_token,
        named: fields
            .named
            .iter()
            .filter(|field| !field.attrs.iter().any(|attr| attr.path().is_ident("omit_new")))
            .cloned()
            .collect(),
    });
    let new_struct = DataStruct {
        fields: new_fields,
        ..original_struct.clone()
    };

    // The insertable variant is named `New<OriginalName>` (e.g. `NewUser`).
    let new_name = format!("New{}", derive_input.ident);
    let new_derive_input = DeriveInput {
        ident: syn::Ident::new(&new_name, derive_input.ident.span()),
        data: Data::Struct(new_struct),
        ..derive_input.clone()
    };

    // The emitted original struct keeps all fields but loses the helper
    // attribute, which only this macro understands.
    let clean_original_fields = Fields::Named(FieldsNamed {
        brace_token: fields.brace_token,
        named: fields
            .named
            .iter()
            .map(|field| {
                let mut field = field.clone();
                field.attrs.retain(|attr| !attr.path().is_ident("omit_new"));
                field
            })
            .collect(),
    });
    let clean_original_struct = DataStruct {
        fields: clean_original_fields,
        ..original_struct.clone()
    };
    let clean_derive_input = DeriveInput {
        ident: derive_input.ident.clone(),
        data: Data::Struct(clean_original_struct),
        ..derive_input.clone()
    };

    // The attribute argument is the Diesel table path, e.g. `crate::schema::users`.
    let table_path = parse_macro_input!(attr as Path);

    // Skip the `Identifiable` derive for structs without an `id` field.
    let full_type_derives = if has_id_field {
        quote! {
            diesel::prelude::Queryable,
            diesel::prelude::Selectable,
            diesel::prelude::Identifiable
        }
    } else {
        quote! {
            diesel::prelude::Queryable,
            diesel::prelude::Selectable
        }
    };

    // All Diesel derives are gated behind the `ssr` feature so the structs
    // remain usable in non-server builds.
    let expanded = quote! {
        #[cfg_attr(feature = "ssr", derive(#full_type_derives))]
        #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
        #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
        #clean_derive_input

        #[cfg_attr(feature = "ssr", derive(diesel::prelude::Insertable))]
        #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
        #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
        #new_derive_input
    };
    expanded.into()
}

View File

@@ -1,124 +1,14 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Path, parse_macro_input}; mod db_type;
mod api_fn;
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream { pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream {
let derive_input = parse_macro_input!(item as DeriveInput); db_type::db_type_impl(attr, item)
}
// Get struct
let original_struct = if let Data::Struct(data_struct) = derive_input.data.clone() { #[proc_macro_attribute]
data_struct pub fn api_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
} else { api_fn::api_fn_impl(attr, item)
return syn::Error::new_spanned(
derive_input.ident,
"#[db_type] can only be applied to structs",
)
.to_compile_error()
.into();
};
// Get fields
let fields = if let Fields::Named(fields) = &original_struct.fields {
fields.clone()
} else {
return syn::Error::new_spanned(
derive_input.ident,
"#[db_type] can only be applied to structs with named fields",
)
.to_compile_error()
.into();
};
let has_id_field = fields.named.iter().any(|field|
field.ident.as_ref()
.map_or(false, |field| field == "id")
);
// Filter out fields with the `#[omit_new]` attribute
let new_fields = fields.named.iter().filter(|field| {
!field
.attrs
.iter()
.any(|attr| attr.path().is_ident("omit_new"))
});
// Wrap back into FieldsNamed
let new_fields = Fields::Named(FieldsNamed {
brace_token: fields.brace_token,
named: new_fields.cloned().collect(),
});
// Create a new struct with the filtered fields
let new_struct = DataStruct {
fields: new_fields,
..original_struct.clone()
};
// Create a new DeriveInput with the new struct, named "New<OriginalName>"
let new_name = format!("New{}", derive_input.ident.clone());
let new_derive_input = DeriveInput {
ident: syn::Ident::new(&new_name, derive_input.ident.span()),
data: Data::Struct(new_struct),
..derive_input.clone()
};
// Remove `#[omit_new]` attributes from the original struct
let clean_original_fields = fields.named.iter().map(|field| {
let mut new_field = field.clone();
new_field
.attrs
.retain(|attr| !attr.path().is_ident("omit_new"));
new_field
});
// Wrap back into FieldsNamed
let clean_original_fields = Fields::Named(FieldsNamed {
brace_token: fields.brace_token,
named: clean_original_fields.collect(),
});
// Create a new struct with the cleaned fields
let clean_original_struct = DataStruct {
fields: clean_original_fields,
..original_struct.clone()
};
// Create a new DeriveInput with the cleaned struct
let clean_derive_input = DeriveInput {
ident: derive_input.ident.clone(),
data: Data::Struct(clean_original_struct),
..derive_input.clone()
};
// Get the table path from the attribute
let table_path = parse_macro_input!(attr as Path);
// Don't include the Identifiable derive for structs without an `id` field
let full_type_derives = if has_id_field {
quote! {
diesel::prelude::Queryable,
diesel::prelude::Selectable,
diesel::prelude::Identifiable
}
} else {
quote! {
diesel::prelude::Queryable,
diesel::prelude::Selectable
}
};
// Generate the expanded code
let expanded = quote! {
#[cfg_attr(feature = "ssr", derive(#full_type_derives))]
#[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
#[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
#clean_derive_input
#[cfg_attr(feature = "ssr", derive(diesel::prelude::Insertable))]
#[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
#[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
#new_derive_input
};
expanded.into()
} }