Add api_fn macro
This commit is contained in:
253
src/api_fn.rs
Normal file
253
src/api_fn.rs
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
use proc_macro::TokenStream;
|
||||||
|
use proc_macro2::TokenStream as TokenStream2;
|
||||||
|
use quote::{format_ident, quote};
|
||||||
|
use syn::{
|
||||||
|
*,
|
||||||
|
punctuated::Punctuated,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Flags {
|
||||||
|
upload: bool,
|
||||||
|
admin: bool,
|
||||||
|
login: bool,
|
||||||
|
endpoint: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct FnArgFlags {
|
||||||
|
need_db_conn: bool,
|
||||||
|
need_user: bool,
|
||||||
|
user_type: Option<Box<Type>>,
|
||||||
|
need_state: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn api_fn_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
|
let attr_args = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
|
||||||
|
let mut input_fn = parse_macro_input!(item as ItemFn);
|
||||||
|
|
||||||
|
let flags = match parse_attributes(&attr_args) {
|
||||||
|
Ok(flags) => flags,
|
||||||
|
Err(e) => return e.to_compile_error().into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get function signature and original name
|
||||||
|
let sig = &input_fn.sig;
|
||||||
|
let orig_name = &sig.ident;
|
||||||
|
let inner_name = format_ident!("{}_inner", orig_name);
|
||||||
|
|
||||||
|
// Process function inputs and identify special parameters
|
||||||
|
let (outer_inputs, fn_arg_flags, call_args) = match process_function_inputs(sig, &flags) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => return e.to_compile_error().into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate pre-processing code
|
||||||
|
let pre_compute = generate_precompute_code(&flags, &fn_arg_flags);
|
||||||
|
|
||||||
|
// Build outer signature
|
||||||
|
let mut outer_sig = sig.clone();
|
||||||
|
outer_sig.inputs = outer_inputs;
|
||||||
|
|
||||||
|
// Rename original function
|
||||||
|
input_fn.sig.ident = inner_name.clone();
|
||||||
|
|
||||||
|
// Build server macro
|
||||||
|
let server_macro = build_server_macro(&flags, &attr_args);
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
#[cfg(feature = "ssr")]
|
||||||
|
#input_fn
|
||||||
|
|
||||||
|
#server_macro
|
||||||
|
pub #outer_sig {
|
||||||
|
#(#pre_compute)*
|
||||||
|
|
||||||
|
#inner_name(#call_args).await
|
||||||
|
}
|
||||||
|
}.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse attributes to `api_fn` into the `Flags` struct
|
||||||
|
fn parse_attributes(attr_args: &Punctuated<Meta, Token![,]>) -> Result<Flags> {
|
||||||
|
let mut endpoint = None;
|
||||||
|
|
||||||
|
let mut upload_flag = false;
|
||||||
|
let mut admin_flag = false;
|
||||||
|
let mut login_flag = false;
|
||||||
|
|
||||||
|
for arg in attr_args {
|
||||||
|
match arg {
|
||||||
|
Meta::NameValue(name_value) => {
|
||||||
|
if name_value.path.is_ident("endpoint") {
|
||||||
|
match &name_value.value {
|
||||||
|
Expr::Lit(expr_lit) => {
|
||||||
|
if let Lit::Str(lit_str) = &expr_lit.lit {
|
||||||
|
endpoint = Some(lit_str.value());
|
||||||
|
} else {
|
||||||
|
return Err(Error::new_spanned(&name_value.value, "endpoint must be a string literal"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(Error::new_spanned(&name_value.value, "Invalid value for endpoint")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Meta::Path(path) => {
|
||||||
|
if path.is_ident("upload") {
|
||||||
|
upload_flag = true;
|
||||||
|
} else if path.is_ident("admin") {
|
||||||
|
admin_flag = true;
|
||||||
|
} else if path.is_ident("login") {
|
||||||
|
login_flag = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let endpoint = endpoint.ok_or_else(|| Error::new_spanned(attr_args, "Must specify API endpoint"))?;
|
||||||
|
|
||||||
|
Ok(Flags {
|
||||||
|
upload: upload_flag,
|
||||||
|
admin: admin_flag,
|
||||||
|
login: login_flag,
|
||||||
|
endpoint
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate inputs to outer function, flags for preprocessing, and inner function call arguments
|
||||||
|
fn process_function_inputs(sig: &Signature, flags: &Flags) -> Result<(Punctuated<FnArg, Token![,]>, FnArgFlags, TokenStream2)> {
|
||||||
|
let mut outer_inputs = Punctuated::new();
|
||||||
|
let mut call_args = vec![];
|
||||||
|
let mut need_db_conn = false;
|
||||||
|
let mut need_user = flags.admin || flags.login;
|
||||||
|
let mut need_state = false;
|
||||||
|
let mut user_type = None;
|
||||||
|
|
||||||
|
for arg in &sig.inputs {
|
||||||
|
match arg {
|
||||||
|
FnArg::Typed(pat_type) => {
|
||||||
|
if let Pat::Ident(PatIdent { ident, .. }) = &*pat_type.pat {
|
||||||
|
let name = ident.to_string();
|
||||||
|
|
||||||
|
if name == "db_conn" {
|
||||||
|
need_db_conn = true;
|
||||||
|
need_state = true;
|
||||||
|
} else if name == "user" {
|
||||||
|
need_user = true;
|
||||||
|
user_type = Some(pat_type.ty.clone());
|
||||||
|
} else if name == "state" {
|
||||||
|
need_state = true;
|
||||||
|
} else {
|
||||||
|
outer_inputs.push(arg.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
call_args.push(quote! { #ident });
|
||||||
|
} else {
|
||||||
|
// Handle non-ident patterns
|
||||||
|
outer_inputs.push(arg.clone());
|
||||||
|
|
||||||
|
call_args.push(quote! { #arg });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FnArg::Receiver(_) => {
|
||||||
|
outer_inputs.push(arg.clone());
|
||||||
|
call_args.push(quote! { self });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let fn_args_flags = FnArgFlags {
|
||||||
|
need_db_conn,
|
||||||
|
need_user,
|
||||||
|
user_type,
|
||||||
|
need_state,
|
||||||
|
};
|
||||||
|
|
||||||
|
let call_args_stream = quote! { #(#call_args),* };
|
||||||
|
|
||||||
|
Ok((outer_inputs, fn_args_flags, call_args_stream))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_precompute_code(flags: &Flags, fn_args: &FnArgFlags) -> Vec<proc_macro2::TokenStream> {
|
||||||
|
let mut pre_compute = Vec::new();
|
||||||
|
|
||||||
|
if fn_args.need_state {
|
||||||
|
let state_err_msg = format!("Failed to get backend state for endpoint `{}`", flags.endpoint);
|
||||||
|
|
||||||
|
pre_compute.push(quote! {
|
||||||
|
let state_result = crate::util::backend_state::BackendState::get().await;
|
||||||
|
let state = crate::util::error::Contextualize::context(state_result, #state_err_msg)?;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if fn_args.need_db_conn {
|
||||||
|
let db_conn_err_msg = format!("Failed to get database connection for endpoint `{}`", flags.endpoint);
|
||||||
|
|
||||||
|
pre_compute.push(quote! {
|
||||||
|
let mut db_conn_result = state.get_db_conn();
|
||||||
|
let mut db_conn = crate::util::error::Contextualize::context(db_conn_result, #db_conn_err_msg)?;
|
||||||
|
let db_conn = &mut db_conn;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let user_type = if let Some(ty) = &fn_args.user_type {
|
||||||
|
quote! { #ty }
|
||||||
|
} else {
|
||||||
|
quote! { Option<crate::models::backend::User> }
|
||||||
|
};
|
||||||
|
|
||||||
|
if fn_args.need_user {
|
||||||
|
pre_compute.push(quote! {
|
||||||
|
let user = <#user_type as crate::api::auth::LoggedInUser>::get().await?;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if flags.admin {
|
||||||
|
pre_compute.push(quote! {
|
||||||
|
let is_admin = <#user_type as crate::api::auth::LoggedInUser>::is_admin(&user);
|
||||||
|
if !is_admin {
|
||||||
|
return Err(crate::util::error::BackendError::AuthError(
|
||||||
|
crate::util::error::AuthError::AdminRequired));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if flags.login {
|
||||||
|
pre_compute.push(quote! {
|
||||||
|
let is_logged_in = <#user_type as crate::api::auth::LoggedInUser>::is_logged_in(&user);
|
||||||
|
if !is_logged_in {
|
||||||
|
return Err(crate::util::error::BackendError::AuthError(
|
||||||
|
crate::util::error::AuthError::Unauthorized));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pre_compute
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_server_macro(flags: &Flags, attr_args: &Punctuated<Meta, Token![,]>) -> TokenStream2 {
|
||||||
|
// Remove flags from server args
|
||||||
|
let server_args: Punctuated<Meta, Token![,]> = attr_args
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.filter(|arg| {
|
||||||
|
match arg {
|
||||||
|
Meta::Path(path) => {
|
||||||
|
!path.is_ident("upload") &&
|
||||||
|
!path.is_ident("admin") &&
|
||||||
|
!path.is_ident("login")
|
||||||
|
}
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if flags.upload {
|
||||||
|
quote! {
|
||||||
|
#[server(input = server_fn::codec::MultipartFormData, #server_args)]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {
|
||||||
|
#[server(client = crate::util::serverfn_client::Client, #server_args)]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,14 @@
|
|||||||
use proc_macro::TokenStream;
|
use proc_macro::TokenStream;
|
||||||
|
|
||||||
mod db_type;
|
mod db_type;
|
||||||
|
mod api_fn;
|
||||||
|
|
||||||
#[proc_macro_attribute]
|
#[proc_macro_attribute]
|
||||||
pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream {
|
pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
db_type::db_type_impl(attr, item)
|
db_type::db_type_impl(attr, item)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[proc_macro_attribute]
|
||||||
|
pub fn api_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||||
|
api_fn::api_fn_impl(attr, item)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user