aboutsummaryrefslogtreecommitdiff
path: root/tosin-macros/src/lib.rs
blob: c9e352c95a82cbf1c73f249afa7ed7b7b5d4964d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#![feature(proc_macro_span)]

extern crate proc_macro;

use std::collections::HashMap;
use std::fs;

use proc_macro::TokenStream;
use quote::quote;

/// Entry point for `#[derive(Model)]`.
///
/// Parses the annotated item as a `syn::DeriveInput` and hands it to
/// `impl_model`, which builds the generated code. Fields may carry
/// `#[model(...)]` helper attributes.
///
/// Input that cannot be parsed is reported as a regular compile error at the
/// offending tokens instead of a proc-macro panic.
#[proc_macro_derive(Model, attributes(model))]
pub fn model_derive(input: TokenStream) -> TokenStream {
    // `parse_macro_input!` returns early with `compile_error!` tokens when
    // parsing fails, so the user sees a spanned diagnostic.
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);

    // Build the generated implementation
    impl_model(&ast)
}

// TODO: clean all of this up

fn to_field_spec(field: &syn::Field) -> (impl quote::ToTokens, impl quote::ToTokens) {
    fn parse_type(ty: &str) -> syn::Type {
        syn::parse_str(ty).unwrap()
    }
    let field_name = &field.ident;
    let field_type = &field.ty;
    let model_options: HashMap<syn::Path, syn::Lit> = field.attrs.iter()
        .filter_map(|attr| attr.parse_meta().ok())
        .filter_map(|meta| if let syn::Meta::List(meta) = meta { Some(meta) } else { None })
        .filter(|meta| meta.path.get_ident().map_or(false, |path| path == "model"))
        .flat_map(|model| {
            model.nested.into_iter().filter_map(|item| {
                if let syn::NestedMeta::Meta(syn::Meta::NameValue(data)) = item {
                    Some((data.path, data.lit))
                } else {
                    None
                }
            })
        })
        .collect();
    if field_type == &parse_type("Option<Id>") {
        ( 
            quote! { ::tosin::db::models::Field::IntField { name: stringify!(#field_name) } },
            quote! { #field_name -> Integer }
        )
    } else if field_type == &parse_type("Id") {
        // TODO foreign key constraint
        ( 
            quote! { ::tosin::db::models::Field::IntField { name: stringify!(#field_name) } },
            quote! { #field_name -> Integer }
        )
    } else if field_type == &parse_type("usize") {
        // TODO default
        ( 
            quote! { ::tosin::db::models::Field::IntField { name: stringify!(#field_name) } },
            quote! { #field_name -> Integer }
        )
    } else if field_type == &parse_type("String") {
        let max_length = model_options.iter()
            .find(|(name, _value)| name.get_ident().map_or(false, |path| path == "max_length"))
            .map(|(_name, value)| value);
        if let Some(max_length) = max_length {
            (
                quote! { ::tosin::db::models::Field::CharField { name: stringify!(#field_name), max_length: Some(#max_length) } },
                quote! { #field_name -> Text }
            )
        } else {
            (
                quote! { ::tosin::db::models::Field::CharField { name: stringify!(#field_name), max_length: None } },
                quote! { #field_name -> Text }
            )
        }
    } else if field_type == &parse_type("time::PrimitiveDateTime") {
        (
            quote! { ::tosin::db::models::Field::DateTimeField { name: stringify!(#field_name) } },
            quote! { #field_name -> Timestamp }
        )
    } else {
        use quote::ToTokens;
        panic!("can't handle {}", field.to_token_stream())
    }
}

fn impl_model(ast: &syn::DeriveInput) -> TokenStream {
    let name = &ast.ident;
    let lowercase_name = quote::format_ident!("{}", name.to_string().to_lowercase());
    let ast_data = if let syn::Data::Struct(ast_data) = &ast.data {
        ast_data
    } else {
        panic!("not on a struct");
    };
    let (tosin_fields, diesel_columns): (Vec<_>, Vec<_>) = ast_data.fields.iter().map(to_field_spec).unzip();
    let gen = quote! {
        impl #name {
            pub const META: ::tosin::db::models::ModelMeta = ::tosin::db::models::ModelMeta {
                name: stringify!(#name),
                fields: &[ #(#tosin_fields),* ],
            };
        }

        // this means users need #[macro_use] extern crate diesel; but fuck doing it ourselves
        table! {
            #lowercase_name {
                #(#diesel_columns,)*
            }
        }
    };
    gen.into()
}

/// Declares every migration module that sits next to the invoking file.
///
/// The directory containing the call-site source file is listed; every entry
/// whose file stem is not `mod` becomes a `mod <stem>;` declaration, and a
/// `pub const ALL: &[Migration]` is emitted that lists `<stem>::MIGRATION`
/// for each of them. `Migration` must be in scope at the call site.
///
/// Module names are sorted and de-duplicated before code generation, because
/// directory iteration order is platform- and filesystem-dependent, and a
/// `foo.rs` file next to a `foo/` directory would otherwise declare `foo`
/// twice.
///
/// # Panics
///
/// Panics when the call site has no real source path, when the directory
/// cannot be read, or when an entry's file stem is not a valid Rust
/// identifier.
#[proc_macro]
pub fn gather_migrations(_input: TokenStream) -> TokenStream {
    let call_site = proc_macro::Span::call_site();
    let call_site_file = call_site.source_file();
    let call_site_path = call_site_file.path();
    if !call_site_file.is_real() {
        panic!("call site does not have a real path");
    }

    let migrations_dir = call_site_path.parent().unwrap_or_else(|| {
        panic!(
            "call site path {} has no parent directory",
            call_site_path.display()
        )
    });

    let mut module_names: Vec<String> = migrations_dir
        .read_dir()
        .unwrap_or_else(|err| {
            panic!(
                "cannot read migrations directory {}: {}",
                migrations_dir.display(),
                err
            )
        })
        .map(|entry| {
            let entry = entry.unwrap_or_else(|err| {
                panic!(
                    "cannot read an entry of {}: {}",
                    migrations_dir.display(),
                    err
                )
            });
            let name = entry
                .path()
                .file_stem()
                .expect("directory entries always have a file name")
                .to_string_lossy()
                .into_owned();
            name
        })
        // The invoking `mod.rs` itself is not a migration.
        .filter(|name| name != "mod")
        .collect();

    // Make the generated module list and `ALL` order reproducible, and avoid
    // declaring the same module twice.
    module_names.sort();
    module_names.dedup();

    let migrations: Vec<syn::Ident> = module_names
        .iter()
        .map(|name| {
            syn::parse_str(name).unwrap_or_else(|err| {
                panic!(
                    "`{}` in {} is not a valid module name: {}",
                    name,
                    migrations_dir.display(),
                    err
                )
            })
        })
        .collect();

    let expanded = quote! {
        #( mod #migrations; )*

        pub const ALL: &[Migration] = &[
            #(#migrations::MIGRATION),*
        ];
    };

    expanded.into()
}

#[proc_macro]
pub fn gather_models(_input: TokenStream) -> TokenStream {
    let call_site = proc_macro::Span::call_site();
    let call_site_file = call_site.source_file();
    let call_site_path = call_site_file.path();
    if !call_site_file.is_real() {
        panic!("call site does not have a real path");
    }

    let call_site_ast = syn::parse_file(&fs::read_to_string(call_site_path).unwrap()).unwrap();
    let models = call_site_ast.items.iter()
        .filter_map(|item| if let syn::Item::Struct(item) = item { Some(item) } else { None })
        .filter(|item| item.attrs.iter().any(|attr| {
            let attr = if let Ok(syn::Meta::List(attr)) = attr.parse_meta() { attr } else { return false; };
            if attr.path.get_ident().map_or(false, |hopefully_derive| hopefully_derive == "derive") {
                let mut derived = attr.nested.iter()
                    .filter_map(|derived| if let syn::NestedMeta::Meta(derived) = derived { Some(derived) } else { None });
                derived.any(|derived| derived.path().get_ident().map_or(false, |hopefully_model| hopefully_model == "Model"))
            } else {
                false
            }
        }))
        .map(|item| &item.ident);

    let gen = quote! {
        pub const ALL: &[tosin::db::models::ModelMeta] = &[
            #(#models::META),*
        ];
    };

    gen.into()
}