diff --git a/Cargo.lock b/Cargo.lock index 014181e..bdf8bc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,3 +5,42 @@ version = 4 [[package]] name = "libretunes_macro" version = "0.1.0" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" diff --git a/Cargo.toml b/Cargo.toml index c7f3773..5e63bdf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,5 @@ edition = "2024" proc-macro = true [dependencies] +quote = "1.0.40" +syn = { version = "2.0.101", features = ["full"] } diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..2c0399f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,108 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Path, parse_macro_input}; + +#[proc_macro_attribute] +pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(item as DeriveInput); + + // Get struct + let original_struct = if let Data::Struct(data_struct) = derive_input.data.clone() { + data_struct + } else { + return syn::Error::new_spanned( + derive_input.ident, + "#[db_type] can only be applied to structs", + ) + .to_compile_error() + .into(); + }; + + // Get fields + let fields = if let Fields::Named(fields) = &original_struct.fields { + fields.clone() + } else { + return syn::Error::new_spanned( + derive_input.ident, + "#[db_type] can only be applied to structs with named fields", + ) + .to_compile_error() + .into(); + }; + + // Filter out fields with the `#[omit_new]` attribute + let new_fields = fields.named.iter().filter(|field| { + !field + .attrs + .iter() + .any(|attr| attr.path().is_ident("omit_new")) + }); + + // Wrap back into FieldsNamed + let new_fields = Fields::Named(FieldsNamed { + brace_token: fields.brace_token, + named: new_fields.cloned().collect(), + }); + + // Create a new struct with the filtered fields + let new_struct = DataStruct { + fields: new_fields, + ..original_struct.clone() + }; + + // Create a new DeriveInput with the new struct, named "New" + let new_name = format!("New{}", derive_input.ident.clone()); + let new_derive_input = DeriveInput { + ident: syn::Ident::new(&new_name, derive_input.ident.span()), + data: Data::Struct(new_struct), + ..derive_input.clone() + }; + + // Remove `#[omit_new]` attributes from the original struct + let clean_original_fields = fields.named.iter().map(|field| { + let mut new_field = field.clone(); + new_field + .attrs + .retain(|attr| !attr.path().is_ident("omit_new")); + new_field + }); + + // Wrap back into FieldsNamed + let clean_original_fields = Fields::Named(FieldsNamed { + brace_token: fields.brace_token, + named: clean_original_fields.collect(), + }); + + // Create a new struct with the cleaned fields + let clean_original_struct = DataStruct { + fields: clean_original_fields, + ..original_struct.clone() + }; + + // Create a new DeriveInput with the cleaned struct + let clean_derive_input = DeriveInput { + ident: derive_input.ident.clone(), + data: Data::Struct(clean_original_struct), + ..derive_input.clone() + }; + + // Get the table path from the attribute + let table_path = parse_macro_input!(attr as Path); + + // Generate the expanded code + let expanded = quote! { + #[cfg_attr(feature = "ssr", derive( + diesel::prelude::Queryable, + diesel::prelude::Selectable, + diesel::prelude::Identifiable))] + #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))] + #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))] + #clean_derive_input + #[cfg_attr(feature = "ssr", derive(diesel::prelude::Insertable))] + #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))] + #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))] + #new_derive_input + }; + + expanded.into() +}