Add api_fn macro

This commit is contained in:
2025-10-20 18:11:09 -04:00
parent c889480913
commit 34726d4f11
2 changed files with 259 additions and 0 deletions

253
src/api_fn.rs Normal file
View File

@@ -0,0 +1,253 @@
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
*,
punctuated::Punctuated,
};
struct Flags {
upload: bool,
admin: bool,
login: bool,
endpoint: String,
}
struct FnArgFlags {
need_db_conn: bool,
need_user: bool,
user_type: Option<Box<Type>>,
need_state: bool,
}
pub fn api_fn_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let attr_args = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
let mut input_fn = parse_macro_input!(item as ItemFn);
let flags = match parse_attributes(&attr_args) {
Ok(flags) => flags,
Err(e) => return e.to_compile_error().into(),
};
// Get function signature and original name
let sig = &input_fn.sig;
let orig_name = &sig.ident;
let inner_name = format_ident!("{}_inner", orig_name);
// Process function inputs and identify special parameters
let (outer_inputs, fn_arg_flags, call_args) = match process_function_inputs(sig, &flags) {
Ok(v) => v,
Err(e) => return e.to_compile_error().into(),
};
// Generate pre-processing code
let pre_compute = generate_precompute_code(&flags, &fn_arg_flags);
// Build outer signature
let mut outer_sig = sig.clone();
outer_sig.inputs = outer_inputs;
// Rename original function
input_fn.sig.ident = inner_name.clone();
// Build server macro
let server_macro = build_server_macro(&flags, &attr_args);
quote! {
#[cfg(feature = "ssr")]
#input_fn
#server_macro
pub #outer_sig {
#(#pre_compute)*
#inner_name(#call_args).await
}
}.into()
}
/// Parse attributes to `api_fn` into the `Flags` struct
fn parse_attributes(attr_args: &Punctuated<Meta, Token![,]>) -> Result<Flags> {
let mut endpoint = None;
let mut upload_flag = false;
let mut admin_flag = false;
let mut login_flag = false;
for arg in attr_args {
match arg {
Meta::NameValue(name_value) => {
if name_value.path.is_ident("endpoint") {
match &name_value.value {
Expr::Lit(expr_lit) => {
if let Lit::Str(lit_str) = &expr_lit.lit {
endpoint = Some(lit_str.value());
} else {
return Err(Error::new_spanned(&name_value.value, "endpoint must be a string literal"));
}
}
_ => return Err(Error::new_spanned(&name_value.value, "Invalid value for endpoint")),
}
}
}
Meta::Path(path) => {
if path.is_ident("upload") {
upload_flag = true;
} else if path.is_ident("admin") {
admin_flag = true;
} else if path.is_ident("login") {
login_flag = true;
}
}
_ => {}
}
}
let endpoint = endpoint.ok_or_else(|| Error::new_spanned(attr_args, "Must specify API endpoint"))?;
Ok(Flags {
upload: upload_flag,
admin: admin_flag,
login: login_flag,
endpoint
})
}
/// Generate inputs to outer function, flags for preprocessing, and inner function call arguments
fn process_function_inputs(sig: &Signature, flags: &Flags) -> Result<(Punctuated<FnArg, Token![,]>, FnArgFlags, TokenStream2)> {
let mut outer_inputs = Punctuated::new();
let mut call_args = vec![];
let mut need_db_conn = false;
let mut need_user = flags.admin || flags.login;
let mut need_state = false;
let mut user_type = None;
for arg in &sig.inputs {
match arg {
FnArg::Typed(pat_type) => {
if let Pat::Ident(PatIdent { ident, .. }) = &*pat_type.pat {
let name = ident.to_string();
if name == "db_conn" {
need_db_conn = true;
need_state = true;
} else if name == "user" {
need_user = true;
user_type = Some(pat_type.ty.clone());
} else if name == "state" {
need_state = true;
} else {
outer_inputs.push(arg.clone());
}
call_args.push(quote! { #ident });
} else {
// Handle non-ident patterns
outer_inputs.push(arg.clone());
call_args.push(quote! { #arg });
}
}
FnArg::Receiver(_) => {
outer_inputs.push(arg.clone());
call_args.push(quote! { self });
}
}
}
let fn_args_flags = FnArgFlags {
need_db_conn,
need_user,
user_type,
need_state,
};
let call_args_stream = quote! { #(#call_args),* };
Ok((outer_inputs, fn_args_flags, call_args_stream))
}
fn generate_precompute_code(flags: &Flags, fn_args: &FnArgFlags) -> Vec<proc_macro2::TokenStream> {
let mut pre_compute = Vec::new();
if fn_args.need_state {
let state_err_msg = format!("Failed to get backend state for endpoint `{}`", flags.endpoint);
pre_compute.push(quote! {
let state_result = crate::util::backend_state::BackendState::get().await;
let state = crate::util::error::Contextualize::context(state_result, #state_err_msg)?;
});
}
if fn_args.need_db_conn {
let db_conn_err_msg = format!("Failed to get database connection for endpoint `{}`", flags.endpoint);
pre_compute.push(quote! {
let mut db_conn_result = state.get_db_conn();
let mut db_conn = crate::util::error::Contextualize::context(db_conn_result, #db_conn_err_msg)?;
let db_conn = &mut db_conn;
});
}
let user_type = if let Some(ty) = &fn_args.user_type {
quote! { #ty }
} else {
quote! { Option<crate::models::backend::User> }
};
if fn_args.need_user {
pre_compute.push(quote! {
let user = <#user_type as crate::api::auth::LoggedInUser>::get().await?;
});
}
if flags.admin {
pre_compute.push(quote! {
let is_admin = <#user_type as crate::api::auth::LoggedInUser>::is_admin(&user);
if !is_admin {
return Err(crate::util::error::BackendError::AuthError(
crate::util::error::AuthError::AdminRequired));
}
});
}
if flags.login {
pre_compute.push(quote! {
let is_logged_in = <#user_type as crate::api::auth::LoggedInUser>::is_logged_in(&user);
if !is_logged_in {
return Err(crate::util::error::BackendError::AuthError(
crate::util::error::AuthError::Unauthorized));
}
});
}
pre_compute
}
fn build_server_macro(flags: &Flags, attr_args: &Punctuated<Meta, Token![,]>) -> TokenStream2 {
// Remove flags from server args
let server_args: Punctuated<Meta, Token![,]> = attr_args
.clone()
.into_iter()
.filter(|arg| {
match arg {
Meta::Path(path) => {
!path.is_ident("upload") &&
!path.is_ident("admin") &&
!path.is_ident("login")
}
_ => true,
}
})
.collect();
if flags.upload {
quote! {
#[server(input = server_fn::codec::MultipartFormData, #server_args)]
}
} else {
quote! {
#[server(client = crate::util::serverfn_client::Client, #server_args)]
}
}
}

View File

@@ -1,8 +1,14 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
mod db_type; mod db_type;
mod api_fn;
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream { pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream {
db_type::db_type_impl(attr, item) db_type::db_type_impl(attr, item)
} }
#[proc_macro_attribute]
pub fn api_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
api_fn::api_fn_impl(attr, item)
}