Compare commits
3 Commits
af30c4b5c9
...
34726d4f11
| Author | SHA1 | Date | |
|---|---|---|---|
|
34726d4f11
|
|||
|
c889480913
|
|||
|
3466212de7
|
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -6,15 +6,16 @@ version = 4
|
||||
name = "libretunes_macro"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.95"
|
||||
version = "1.0.101"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778"
|
||||
checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
@@ -7,5 +7,6 @@ edition = "2024"
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0.101"
|
||||
quote = "1.0.40"
|
||||
syn = { version = "2.0.101", features = ["full"] }
|
||||
|
||||
253
src/api_fn.rs
Normal file
253
src/api_fn.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::TokenStream as TokenStream2;
|
||||
use quote::{format_ident, quote};
|
||||
use syn::{
|
||||
*,
|
||||
punctuated::Punctuated,
|
||||
};
|
||||
|
||||
struct Flags {
|
||||
upload: bool,
|
||||
admin: bool,
|
||||
login: bool,
|
||||
endpoint: String,
|
||||
}
|
||||
|
||||
struct FnArgFlags {
|
||||
need_db_conn: bool,
|
||||
need_user: bool,
|
||||
user_type: Option<Box<Type>>,
|
||||
need_state: bool,
|
||||
}
|
||||
|
||||
pub fn api_fn_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let attr_args = parse_macro_input!(attr with Punctuated::<Meta, Token![,]>::parse_terminated);
|
||||
let mut input_fn = parse_macro_input!(item as ItemFn);
|
||||
|
||||
let flags = match parse_attributes(&attr_args) {
|
||||
Ok(flags) => flags,
|
||||
Err(e) => return e.to_compile_error().into(),
|
||||
};
|
||||
|
||||
// Get function signature and original name
|
||||
let sig = &input_fn.sig;
|
||||
let orig_name = &sig.ident;
|
||||
let inner_name = format_ident!("{}_inner", orig_name);
|
||||
|
||||
// Process function inputs and identify special parameters
|
||||
let (outer_inputs, fn_arg_flags, call_args) = match process_function_inputs(sig, &flags) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return e.to_compile_error().into(),
|
||||
};
|
||||
|
||||
// Generate pre-processing code
|
||||
let pre_compute = generate_precompute_code(&flags, &fn_arg_flags);
|
||||
|
||||
// Build outer signature
|
||||
let mut outer_sig = sig.clone();
|
||||
outer_sig.inputs = outer_inputs;
|
||||
|
||||
// Rename original function
|
||||
input_fn.sig.ident = inner_name.clone();
|
||||
|
||||
// Build server macro
|
||||
let server_macro = build_server_macro(&flags, &attr_args);
|
||||
|
||||
quote! {
|
||||
#[cfg(feature = "ssr")]
|
||||
#input_fn
|
||||
|
||||
#server_macro
|
||||
pub #outer_sig {
|
||||
#(#pre_compute)*
|
||||
|
||||
#inner_name(#call_args).await
|
||||
}
|
||||
}.into()
|
||||
}
|
||||
|
||||
/// Parse attributes to `api_fn` into the `Flags` struct
|
||||
fn parse_attributes(attr_args: &Punctuated<Meta, Token![,]>) -> Result<Flags> {
|
||||
let mut endpoint = None;
|
||||
|
||||
let mut upload_flag = false;
|
||||
let mut admin_flag = false;
|
||||
let mut login_flag = false;
|
||||
|
||||
for arg in attr_args {
|
||||
match arg {
|
||||
Meta::NameValue(name_value) => {
|
||||
if name_value.path.is_ident("endpoint") {
|
||||
match &name_value.value {
|
||||
Expr::Lit(expr_lit) => {
|
||||
if let Lit::Str(lit_str) = &expr_lit.lit {
|
||||
endpoint = Some(lit_str.value());
|
||||
} else {
|
||||
return Err(Error::new_spanned(&name_value.value, "endpoint must be a string literal"));
|
||||
}
|
||||
}
|
||||
_ => return Err(Error::new_spanned(&name_value.value, "Invalid value for endpoint")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Meta::Path(path) => {
|
||||
if path.is_ident("upload") {
|
||||
upload_flag = true;
|
||||
} else if path.is_ident("admin") {
|
||||
admin_flag = true;
|
||||
} else if path.is_ident("login") {
|
||||
login_flag = true;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let endpoint = endpoint.ok_or_else(|| Error::new_spanned(attr_args, "Must specify API endpoint"))?;
|
||||
|
||||
Ok(Flags {
|
||||
upload: upload_flag,
|
||||
admin: admin_flag,
|
||||
login: login_flag,
|
||||
endpoint
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate inputs to outer function, flags for preprocessing, and inner function call arguments
|
||||
fn process_function_inputs(sig: &Signature, flags: &Flags) -> Result<(Punctuated<FnArg, Token![,]>, FnArgFlags, TokenStream2)> {
|
||||
let mut outer_inputs = Punctuated::new();
|
||||
let mut call_args = vec![];
|
||||
let mut need_db_conn = false;
|
||||
let mut need_user = flags.admin || flags.login;
|
||||
let mut need_state = false;
|
||||
let mut user_type = None;
|
||||
|
||||
for arg in &sig.inputs {
|
||||
match arg {
|
||||
FnArg::Typed(pat_type) => {
|
||||
if let Pat::Ident(PatIdent { ident, .. }) = &*pat_type.pat {
|
||||
let name = ident.to_string();
|
||||
|
||||
if name == "db_conn" {
|
||||
need_db_conn = true;
|
||||
need_state = true;
|
||||
} else if name == "user" {
|
||||
need_user = true;
|
||||
user_type = Some(pat_type.ty.clone());
|
||||
} else if name == "state" {
|
||||
need_state = true;
|
||||
} else {
|
||||
outer_inputs.push(arg.clone());
|
||||
}
|
||||
|
||||
call_args.push(quote! { #ident });
|
||||
} else {
|
||||
// Handle non-ident patterns
|
||||
outer_inputs.push(arg.clone());
|
||||
|
||||
call_args.push(quote! { #arg });
|
||||
}
|
||||
}
|
||||
FnArg::Receiver(_) => {
|
||||
outer_inputs.push(arg.clone());
|
||||
call_args.push(quote! { self });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let fn_args_flags = FnArgFlags {
|
||||
need_db_conn,
|
||||
need_user,
|
||||
user_type,
|
||||
need_state,
|
||||
};
|
||||
|
||||
let call_args_stream = quote! { #(#call_args),* };
|
||||
|
||||
Ok((outer_inputs, fn_args_flags, call_args_stream))
|
||||
}
|
||||
|
||||
fn generate_precompute_code(flags: &Flags, fn_args: &FnArgFlags) -> Vec<proc_macro2::TokenStream> {
|
||||
let mut pre_compute = Vec::new();
|
||||
|
||||
if fn_args.need_state {
|
||||
let state_err_msg = format!("Failed to get backend state for endpoint `{}`", flags.endpoint);
|
||||
|
||||
pre_compute.push(quote! {
|
||||
let state_result = crate::util::backend_state::BackendState::get().await;
|
||||
let state = crate::util::error::Contextualize::context(state_result, #state_err_msg)?;
|
||||
});
|
||||
}
|
||||
|
||||
if fn_args.need_db_conn {
|
||||
let db_conn_err_msg = format!("Failed to get database connection for endpoint `{}`", flags.endpoint);
|
||||
|
||||
pre_compute.push(quote! {
|
||||
let mut db_conn_result = state.get_db_conn();
|
||||
let mut db_conn = crate::util::error::Contextualize::context(db_conn_result, #db_conn_err_msg)?;
|
||||
let db_conn = &mut db_conn;
|
||||
});
|
||||
}
|
||||
|
||||
let user_type = if let Some(ty) = &fn_args.user_type {
|
||||
quote! { #ty }
|
||||
} else {
|
||||
quote! { Option<crate::models::backend::User> }
|
||||
};
|
||||
|
||||
if fn_args.need_user {
|
||||
pre_compute.push(quote! {
|
||||
let user = <#user_type as crate::api::auth::LoggedInUser>::get().await?;
|
||||
});
|
||||
}
|
||||
|
||||
if flags.admin {
|
||||
pre_compute.push(quote! {
|
||||
let is_admin = <#user_type as crate::api::auth::LoggedInUser>::is_admin(&user);
|
||||
if !is_admin {
|
||||
return Err(crate::util::error::BackendError::AuthError(
|
||||
crate::util::error::AuthError::AdminRequired));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if flags.login {
|
||||
pre_compute.push(quote! {
|
||||
let is_logged_in = <#user_type as crate::api::auth::LoggedInUser>::is_logged_in(&user);
|
||||
if !is_logged_in {
|
||||
return Err(crate::util::error::BackendError::AuthError(
|
||||
crate::util::error::AuthError::Unauthorized));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pre_compute
|
||||
}
|
||||
|
||||
fn build_server_macro(flags: &Flags, attr_args: &Punctuated<Meta, Token![,]>) -> TokenStream2 {
|
||||
// Remove flags from server args
|
||||
let server_args: Punctuated<Meta, Token![,]> = attr_args
|
||||
.clone()
|
||||
.into_iter()
|
||||
.filter(|arg| {
|
||||
match arg {
|
||||
Meta::Path(path) => {
|
||||
!path.is_ident("upload") &&
|
||||
!path.is_ident("admin") &&
|
||||
!path.is_ident("login")
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if flags.upload {
|
||||
quote! {
|
||||
#[server(input = server_fn::codec::MultipartFormData, #server_args)]
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#[server(client = crate::util::serverfn_client::Client, #server_args)]
|
||||
}
|
||||
}
|
||||
}
|
||||
123
src/db_type.rs
Normal file
123
src/db_type.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Path, parse_macro_input};
|
||||
|
||||
pub fn db_type_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let derive_input = parse_macro_input!(item as DeriveInput);
|
||||
|
||||
// Get struct
|
||||
let original_struct = if let Data::Struct(data_struct) = derive_input.data.clone() {
|
||||
data_struct
|
||||
} else {
|
||||
return syn::Error::new_spanned(
|
||||
derive_input.ident,
|
||||
"#[db_type] can only be applied to structs",
|
||||
)
|
||||
.to_compile_error()
|
||||
.into();
|
||||
};
|
||||
|
||||
// Get fields
|
||||
let fields = if let Fields::Named(fields) = &original_struct.fields {
|
||||
fields.clone()
|
||||
} else {
|
||||
return syn::Error::new_spanned(
|
||||
derive_input.ident,
|
||||
"#[db_type] can only be applied to structs with named fields",
|
||||
)
|
||||
.to_compile_error()
|
||||
.into();
|
||||
};
|
||||
|
||||
let has_id_field = fields.named.iter().any(|field|
|
||||
field.ident.as_ref()
|
||||
.map_or(false, |field| field == "id")
|
||||
);
|
||||
|
||||
// Filter out fields with the `#[omit_new]` attribute
|
||||
let new_fields = fields.named.iter().filter(|field| {
|
||||
!field
|
||||
.attrs
|
||||
.iter()
|
||||
.any(|attr| attr.path().is_ident("omit_new"))
|
||||
});
|
||||
|
||||
// Wrap back into FieldsNamed
|
||||
let new_fields = Fields::Named(FieldsNamed {
|
||||
brace_token: fields.brace_token,
|
||||
named: new_fields.cloned().collect(),
|
||||
});
|
||||
|
||||
// Create a new struct with the filtered fields
|
||||
let new_struct = DataStruct {
|
||||
fields: new_fields,
|
||||
..original_struct.clone()
|
||||
};
|
||||
|
||||
// Create a new DeriveInput with the new struct, named "New<OriginalName>"
|
||||
let new_name = format!("New{}", derive_input.ident.clone());
|
||||
let new_derive_input = DeriveInput {
|
||||
ident: syn::Ident::new(&new_name, derive_input.ident.span()),
|
||||
data: Data::Struct(new_struct),
|
||||
..derive_input.clone()
|
||||
};
|
||||
|
||||
// Remove `#[omit_new]` attributes from the original struct
|
||||
let clean_original_fields = fields.named.iter().map(|field| {
|
||||
let mut new_field = field.clone();
|
||||
new_field
|
||||
.attrs
|
||||
.retain(|attr| !attr.path().is_ident("omit_new"));
|
||||
new_field
|
||||
});
|
||||
|
||||
// Wrap back into FieldsNamed
|
||||
let clean_original_fields = Fields::Named(FieldsNamed {
|
||||
brace_token: fields.brace_token,
|
||||
named: clean_original_fields.collect(),
|
||||
});
|
||||
|
||||
// Create a new struct with the cleaned fields
|
||||
let clean_original_struct = DataStruct {
|
||||
fields: clean_original_fields,
|
||||
..original_struct.clone()
|
||||
};
|
||||
|
||||
// Create a new DeriveInput with the cleaned struct
|
||||
let clean_derive_input = DeriveInput {
|
||||
ident: derive_input.ident.clone(),
|
||||
data: Data::Struct(clean_original_struct),
|
||||
..derive_input.clone()
|
||||
};
|
||||
|
||||
// Get the table path from the attribute
|
||||
let table_path = parse_macro_input!(attr as Path);
|
||||
|
||||
// Don't include the Identifiable derive for structs without an `id` field
|
||||
let full_type_derives = if has_id_field {
|
||||
quote! {
|
||||
diesel::prelude::Queryable,
|
||||
diesel::prelude::Selectable,
|
||||
diesel::prelude::Identifiable
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
diesel::prelude::Queryable,
|
||||
diesel::prelude::Selectable
|
||||
}
|
||||
};
|
||||
|
||||
// Generate the expanded code
|
||||
let expanded = quote! {
|
||||
#[cfg_attr(feature = "ssr", derive(#full_type_derives))]
|
||||
#[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
|
||||
#[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
|
||||
#clean_derive_input
|
||||
#[cfg_attr(feature = "ssr", derive(diesel::prelude::Insertable))]
|
||||
#[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
|
||||
#[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
|
||||
#new_derive_input
|
||||
};
|
||||
|
||||
expanded.into()
|
||||
}
|
||||
128
src/lib.rs
128
src/lib.rs
@@ -1,124 +1,14 @@
|
||||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Path, parse_macro_input};
|
||||
|
||||
mod db_type;
|
||||
mod api_fn;
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let derive_input = parse_macro_input!(item as DeriveInput);
|
||||
|
||||
// Get struct
|
||||
let original_struct = if let Data::Struct(data_struct) = derive_input.data.clone() {
|
||||
data_struct
|
||||
} else {
|
||||
return syn::Error::new_spanned(
|
||||
derive_input.ident,
|
||||
"#[db_type] can only be applied to structs",
|
||||
)
|
||||
.to_compile_error()
|
||||
.into();
|
||||
};
|
||||
|
||||
// Get fields
|
||||
let fields = if let Fields::Named(fields) = &original_struct.fields {
|
||||
fields.clone()
|
||||
} else {
|
||||
return syn::Error::new_spanned(
|
||||
derive_input.ident,
|
||||
"#[db_type] can only be applied to structs with named fields",
|
||||
)
|
||||
.to_compile_error()
|
||||
.into();
|
||||
};
|
||||
|
||||
let has_id_field = fields.named.iter().any(|field|
|
||||
field.ident.as_ref()
|
||||
.map_or(false, |field| field == "id")
|
||||
);
|
||||
|
||||
// Filter out fields with the `#[omit_new]` attribute
|
||||
let new_fields = fields.named.iter().filter(|field| {
|
||||
!field
|
||||
.attrs
|
||||
.iter()
|
||||
.any(|attr| attr.path().is_ident("omit_new"))
|
||||
});
|
||||
|
||||
// Wrap back into FieldsNamed
|
||||
let new_fields = Fields::Named(FieldsNamed {
|
||||
brace_token: fields.brace_token,
|
||||
named: new_fields.cloned().collect(),
|
||||
});
|
||||
|
||||
// Create a new struct with the filtered fields
|
||||
let new_struct = DataStruct {
|
||||
fields: new_fields,
|
||||
..original_struct.clone()
|
||||
};
|
||||
|
||||
// Create a new DeriveInput with the new struct, named "New<OriginalName>"
|
||||
let new_name = format!("New{}", derive_input.ident.clone());
|
||||
let new_derive_input = DeriveInput {
|
||||
ident: syn::Ident::new(&new_name, derive_input.ident.span()),
|
||||
data: Data::Struct(new_struct),
|
||||
..derive_input.clone()
|
||||
};
|
||||
|
||||
// Remove `#[omit_new]` attributes from the original struct
|
||||
let clean_original_fields = fields.named.iter().map(|field| {
|
||||
let mut new_field = field.clone();
|
||||
new_field
|
||||
.attrs
|
||||
.retain(|attr| !attr.path().is_ident("omit_new"));
|
||||
new_field
|
||||
});
|
||||
|
||||
// Wrap back into FieldsNamed
|
||||
let clean_original_fields = Fields::Named(FieldsNamed {
|
||||
brace_token: fields.brace_token,
|
||||
named: clean_original_fields.collect(),
|
||||
});
|
||||
|
||||
// Create a new struct with the cleaned fields
|
||||
let clean_original_struct = DataStruct {
|
||||
fields: clean_original_fields,
|
||||
..original_struct.clone()
|
||||
};
|
||||
|
||||
// Create a new DeriveInput with the cleaned struct
|
||||
let clean_derive_input = DeriveInput {
|
||||
ident: derive_input.ident.clone(),
|
||||
data: Data::Struct(clean_original_struct),
|
||||
..derive_input.clone()
|
||||
};
|
||||
|
||||
// Get the table path from the attribute
|
||||
let table_path = parse_macro_input!(attr as Path);
|
||||
|
||||
// Don't include the Identifiable derive for structs without an `id` field
|
||||
let full_type_derives = if has_id_field {
|
||||
quote! {
|
||||
diesel::prelude::Queryable,
|
||||
diesel::prelude::Selectable,
|
||||
diesel::prelude::Identifiable
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
diesel::prelude::Queryable,
|
||||
diesel::prelude::Selectable
|
||||
}
|
||||
};
|
||||
|
||||
// Generate the expanded code
|
||||
let expanded = quote! {
|
||||
#[cfg_attr(feature = "ssr", derive(#full_type_derives))]
|
||||
#[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
|
||||
#[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
|
||||
#clean_derive_input
|
||||
#[cfg_attr(feature = "ssr", derive(diesel::prelude::Insertable))]
|
||||
#[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))]
|
||||
#[cfg_attr(feature = "ssr", diesel(table_name = #table_path))]
|
||||
#new_derive_input
|
||||
};
|
||||
|
||||
expanded.into()
|
||||
db_type::db_type_impl(attr, item)
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn api_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
api_fn::api_fn_impl(attr, item)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user