diff --git a/src/db_type.rs b/src/db_type.rs new file mode 100644 index 0000000..3bbf526 --- /dev/null +++ b/src/db_type.rs @@ -0,0 +1,123 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Path, parse_macro_input}; + +pub fn db_type_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let derive_input = parse_macro_input!(item as DeriveInput); + + // Get struct + let original_struct = if let Data::Struct(data_struct) = derive_input.data.clone() { + data_struct + } else { + return syn::Error::new_spanned( + derive_input.ident, + "#[db_type] can only be applied to structs", + ) + .to_compile_error() + .into(); + }; + + // Get fields + let fields = if let Fields::Named(fields) = &original_struct.fields { + fields.clone() + } else { + return syn::Error::new_spanned( + derive_input.ident, + "#[db_type] can only be applied to structs with named fields", + ) + .to_compile_error() + .into(); + }; + + let has_id_field = fields.named.iter().any(|field| + field.ident.as_ref() + .map_or(false, |field| field == "id") + ); + + // Filter out fields with the `#[omit_new]` attribute + let new_fields = fields.named.iter().filter(|field| { + !field + .attrs + .iter() + .any(|attr| attr.path().is_ident("omit_new")) + }); + + // Wrap back into FieldsNamed + let new_fields = Fields::Named(FieldsNamed { + brace_token: fields.brace_token, + named: new_fields.cloned().collect(), + }); + + // Create a new struct with the filtered fields + let new_struct = DataStruct { + fields: new_fields, + ..original_struct.clone() + }; + + // Create a new DeriveInput with the new struct, named "New" + let new_name = format!("New{}", derive_input.ident.clone()); + let new_derive_input = DeriveInput { + ident: syn::Ident::new(&new_name, derive_input.ident.span()), + data: Data::Struct(new_struct), + ..derive_input.clone() + }; + + // Remove `#[omit_new]` attributes from the original struct + let clean_original_fields = fields.named.iter().map(|field| { + let mut new_field = field.clone(); + new_field + .attrs + .retain(|attr| !attr.path().is_ident("omit_new")); + new_field + }); + + // Wrap back into FieldsNamed + let clean_original_fields = Fields::Named(FieldsNamed { + brace_token: fields.brace_token, + named: clean_original_fields.collect(), + }); + + // Create a new struct with the cleaned fields + let clean_original_struct = DataStruct { + fields: clean_original_fields, + ..original_struct.clone() + }; + + // Create a new DeriveInput with the cleaned struct + let clean_derive_input = DeriveInput { + ident: derive_input.ident.clone(), + data: Data::Struct(clean_original_struct), + ..derive_input.clone() + }; + + // Get the table path from the attribute + let table_path = parse_macro_input!(attr as Path); + + // Don't include the Identifiable derive for structs without an `id` field + let full_type_derives = if has_id_field { + quote! { + diesel::prelude::Queryable, + diesel::prelude::Selectable, + diesel::prelude::Identifiable + } + } else { + quote! { + diesel::prelude::Queryable, + diesel::prelude::Selectable + } + }; + + // Generate the expanded code + let expanded = quote! { + #[cfg_attr(feature = "ssr", derive(#full_type_derives))] + #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))] + #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))] + #clean_derive_input + #[cfg_attr(feature = "ssr", derive(diesel::prelude::Insertable))] + #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))] + #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))] + #new_derive_input + }; + + expanded.into() +} diff --git a/src/lib.rs b/src/lib.rs index c572ebb..2a894d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,124 +1,8 @@ use proc_macro::TokenStream; -use quote::quote; -use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Path, parse_macro_input}; + +mod db_type; #[proc_macro_attribute] pub fn db_type(attr: TokenStream, item: TokenStream) -> TokenStream { - let derive_input = parse_macro_input!(item as DeriveInput); - - // Get struct - let original_struct = if let Data::Struct(data_struct) = derive_input.data.clone() { - data_struct - } else { - return syn::Error::new_spanned( - derive_input.ident, - "#[db_type] can only be applied to structs", - ) - .to_compile_error() - .into(); - }; - - // Get fields - let fields = if let Fields::Named(fields) = &original_struct.fields { - fields.clone() - } else { - return syn::Error::new_spanned( - derive_input.ident, - "#[db_type] can only be applied to structs with named fields", - ) - .to_compile_error() - .into(); - }; - - let has_id_field = fields.named.iter().any(|field| - field.ident.as_ref() - .map_or(false, |field| field == "id") - ); - - // Filter out fields with the `#[omit_new]` attribute - let new_fields = fields.named.iter().filter(|field| { - !field - .attrs - .iter() - .any(|attr| attr.path().is_ident("omit_new")) - }); - - // Wrap back into FieldsNamed - let new_fields = Fields::Named(FieldsNamed { - brace_token: fields.brace_token, - named: new_fields.cloned().collect(), - }); - - // Create a new struct with the filtered fields - let new_struct = DataStruct { - fields: new_fields, - ..original_struct.clone() - }; - - // Create a new DeriveInput with the new struct, named "New" - let new_name = format!("New{}", derive_input.ident.clone()); - let new_derive_input = DeriveInput { - ident: syn::Ident::new(&new_name, derive_input.ident.span()), - data: Data::Struct(new_struct), - ..derive_input.clone() - }; - - // Remove `#[omit_new]` attributes from the original struct - let clean_original_fields = fields.named.iter().map(|field| { - let mut new_field = field.clone(); - new_field - .attrs - .retain(|attr| !attr.path().is_ident("omit_new")); - new_field - }); - - // Wrap back into FieldsNamed - let clean_original_fields = Fields::Named(FieldsNamed { - brace_token: fields.brace_token, - named: clean_original_fields.collect(), - }); - - // Create a new struct with the cleaned fields - let clean_original_struct = DataStruct { - fields: clean_original_fields, - ..original_struct.clone() - }; - - // Create a new DeriveInput with the cleaned struct - let clean_derive_input = DeriveInput { - ident: derive_input.ident.clone(), - data: Data::Struct(clean_original_struct), - ..derive_input.clone() - }; - - // Get the table path from the attribute - let table_path = parse_macro_input!(attr as Path); - - // Don't include the Identifiable derive for structs without an `id` field - let full_type_derives = if has_id_field { - quote! { - diesel::prelude::Queryable, - diesel::prelude::Selectable, - diesel::prelude::Identifiable - } - } else { - quote! { - diesel::prelude::Queryable, - diesel::prelude::Selectable - } - }; - - // Generate the expanded code - let expanded = quote! { - #[cfg_attr(feature = "ssr", derive(#full_type_derives))] - #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))] - #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))] - #clean_derive_input - #[cfg_attr(feature = "ssr", derive(diesel::prelude::Insertable))] - #[cfg_attr(feature = "ssr", diesel(check_for_backend(diesel::pg::Pg)))] - #[cfg_attr(feature = "ssr", diesel(table_name = #table_path))] - #new_derive_input - }; - - expanded.into() + db_type::db_type_impl(attr, item) }