diff --git a/src/api_fn.rs b/src/api_fn.rs new file mode 100644 index 0000000..0ce57c8 --- /dev/null +++ b/src/api_fn.rs @@ -0,0 +1,253 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote}; +use syn::{ + *, + punctuated::Punctuated, +}; + +struct Flags { + upload: bool, + admin: bool, + login: bool, + endpoint: String, +} + +struct FnArgFlags { + need_db_conn: bool, + need_user: bool, + user_type: Option>, + need_state: bool, +} + +pub fn api_fn_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr_args = parse_macro_input!(attr with Punctuated::::parse_terminated); + let mut input_fn = parse_macro_input!(item as ItemFn); + + let flags = match parse_attributes(&attr_args) { + Ok(flags) => flags, + Err(e) => return e.to_compile_error().into(), + }; + + // Get function signature and original name + let sig = &input_fn.sig; + let orig_name = &sig.ident; + let inner_name = format_ident!("{}_inner", orig_name); + + // Process function inputs and identify special parameters + let (outer_inputs, fn_arg_flags, call_args) = match process_function_inputs(sig, &flags) { + Ok(v) => v, + Err(e) => return e.to_compile_error().into(), + }; + + // Generate pre-processing code + let pre_compute = generate_precompute_code(&flags, &fn_arg_flags); + + // Build outer signature + let mut outer_sig = sig.clone(); + outer_sig.inputs = outer_inputs; + + // Rename original function + input_fn.sig.ident = inner_name.clone(); + + // Build server macro + let server_macro = build_server_macro(&flags, &attr_args); + + quote! { + #[cfg(feature = "ssr")] + #input_fn + + #server_macro + pub #outer_sig { + #(#pre_compute)* + + #inner_name(#call_args).await + } + }.into() +} + +/// Parse attributes to `api_fn` into the `Flags` struct +fn parse_attributes(attr_args: &Punctuated) -> Result { + let mut endpoint = None; + + let mut upload_flag = false; + let mut admin_flag = false; + let mut login_flag = false; + + for arg in attr_args { + match arg { + Meta::NameValue(name_value) => { + if name_value.path.is_ident("endpoint") { + match &name_value.value { + Expr::Lit(expr_lit) => { + if let Lit::Str(lit_str) = &expr_lit.lit { + endpoint = Some(lit_str.value()); + } else { + return Err(Error::new_spanned(&name_value.value, "endpoint must be a string literal")); + } + } + _ => return Err(Error::new_spanned(&name_value.value, "Invalid value for endpoint")), + } + } + } + Meta::Path(path) => { + if path.is_ident("upload") { + upload_flag = true; + } else if path.is_ident("admin") { + admin_flag = true; + } else if path.is_ident("login") { + login_flag = true; + } + } + _ => {} + } + } + + let endpoint = endpoint.ok_or_else(|| Error::new_spanned(attr_args, "Must specify API endpoint"))?; + + Ok(Flags { + upload: upload_flag, + admin: admin_flag, + login: login_flag, + endpoint + }) +} + +/// Generate inputs to outer function, flags for preprocessing, and inner function call arguments +fn process_function_inputs(sig: &Signature, flags: &Flags) -> Result<(Punctuated, FnArgFlags, TokenStream2)> { + let mut outer_inputs = Punctuated::new(); + let mut call_args = vec![]; + let mut need_db_conn = false; + let mut need_user = flags.admin || flags.login; + let mut need_state = false; + let mut user_type = None; + + for arg in &sig.inputs { + match arg { + FnArg::Typed(pat_type) => { + if let Pat::Ident(PatIdent { ident, .. }) = &*pat_type.pat { + let name = ident.to_string(); + + if name == "db_conn" { + need_db_conn = true; + need_state = true; + } else if name == "user" { + need_user = true; + user_type = Some(pat_type.ty.clone()); + } else if name == "state" { + need_state = true; + } else { + outer_inputs.push(arg.clone()); + } + + call_args.push(quote! { #ident }); + } else { + // Handle non-ident patterns + outer_inputs.push(arg.clone()); + + call_args.push(quote! { #arg }); + } + } + FnArg::Receiver(_) => { + outer_inputs.push(arg.clone()); + call_args.push(quote! { self }); + } + } + } + + let fn_args_flags = FnArgFlags { + need_db_conn, + need_user, + user_type, + need_state, + }; + + let call_args_stream = quote! { #(#call_args),* }; + + Ok((outer_inputs, fn_args_flags, call_args_stream)) +} + +fn generate_precompute_code(flags: &Flags, fn_args: &FnArgFlags) -> Vec { + let mut pre_compute = Vec::new(); + + if fn_args.need_state { + let state_err_msg = format!("Failed to get backend state for endpoint `{}`", flags.endpoint); + + pre_compute.push(quote! { + let state_result = crate::util::backend_state::BackendState::get().await; + let state = crate::util::error::Contextualize::context(state_result, #state_err_msg)?; + }); + } + + if fn_args.need_db_conn { + let db_conn_err_msg = format!("Failed to get database connection for endpoint `{}`", flags.endpoint); + + pre_compute.push(quote! { + let mut db_conn_result = state.get_db_conn(); + let mut db_conn = crate::util::error::Contextualize::context(db_conn_result, #db_conn_err_msg)?; + let db_conn = &mut db_conn; + }); + } + + let user_type = if let Some(ty) = &fn_args.user_type { + quote! { #ty } + } else { + quote! { Option } + }; + + if fn_args.need_user { + pre_compute.push(quote! { + let user = <#user_type as crate::api::auth::LoggedInUser>::get().await?; + }); + } + + if flags.admin { + pre_compute.push(quote! { + let is_admin = <#user_type as crate::api::auth::LoggedInUser>::is_admin(&user); + if !is_admin { + return Err(crate::util::error::BackendError::AuthError( + crate::util::error::AuthError::AdminRequired)); + } + }); + } + + if flags.login { + pre_compute.push(quote! { + let is_logged_in = <#user_type as crate::api::auth::LoggedInUser>::is_logged_in(&user); + if !is_logged_in { + return Err(crate::util::error::BackendError::AuthError( + crate::util::error::AuthError::Unauthorized)); + } + }); + } + + pre_compute +} + +fn build_server_macro(flags: &Flags, attr_args: &Punctuated) -> TokenStream2 { + // Remove flags from server args + let server_args: Punctuated = attr_args + .clone() + .into_iter() + .filter(|arg| { + match arg { + Meta::Path(path) => { + !path.is_ident("upload") && + !path.is_ident("admin") && + !path.is_ident("login") + } + _ => true, + } + }) + .collect(); + + if flags.upload { + quote! { + #[server(input = server_fn::codec::MultipartFormData, #server_args)] + } + } else { + quote! { + #[server(client = crate::util::serverfn_client::Client, #server_args)] + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 2a894d8..f6758ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,14 @@ use proc_macro::TokenStream; mod db_type; +mod api_fn; #[proc_macro_attribute] pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream { db_type::db_type_impl(attr, item) } + +#[proc_macro_attribute] +pub fn api_fn(attr: TokenStream, item: TokenStream) -> TokenStream { + api_fn::api_fn_impl(attr, item) +}