Skip to main content

karyon_jsonrpc_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    parse_macro_input, spanned::Spanned, FnArg, ImplItem, ItemFn, ItemImpl, LitStr, ReturnType,
6    Signature, Type, TypePath,
7};
8
9#[proc_macro_attribute]
10pub fn rpc_method(_attr: TokenStream, item: TokenStream) -> TokenStream {
11    let item_fn = parse_macro_input!(item as ItemFn);
12    TokenStream::from(quote! {
13        #item_fn
14    })
15}
16
/// Aborts the enclosing proc-macro function with a compile error.
///
/// The arguments (a span and a message) are forwarded to `syn::Error::new`.
/// The error is turned into `compile_error!` tokens via `to_compile_error()`,
/// converted with `.into()`, and `return`ed from the caller.
macro_rules! err {
    ($($arg:tt)*) => {
        return syn::Error::new($($arg)*).to_compile_error().into()
    };
}
22
23#[proc_macro_attribute]
24pub fn rpc_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
25    expand_service_impl(attr, item, ServiceKind::Rpc)
26}
27
28#[proc_macro_attribute]
29pub fn rpc_pubsub_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
30    expand_service_impl(attr, item, ServiceKind::PubSub)
31}
32
/// Selects which server trait an annotated `impl` block is expanded into.
#[derive(Copy, Clone)]
enum ServiceKind {
    /// Expands into an impl of `karyon_jsonrpc::server::RPCService`.
    Rpc,
    /// Expands into an impl of `karyon_jsonrpc::server::PubSubRPCService`.
    PubSub,
}
38
39impl ServiceKind {
40    fn arity(self) -> usize {
41        match self {
42            // &self + serde_json::Value
43            ServiceKind::Rpc => 2,
44            // &self + Arc<Channel>, String, serde_json::Value
45            ServiceKind::PubSub => 4,
46        }
47    }
48
49    fn arity_error(self) -> &'static str {
50        match self {
51            ServiceKind::Rpc => "requires `&self` and a parameter of type `serde_json::Value`",
52            ServiceKind::PubSub => {
53                "requires `&self` and three parameters: \
54                `Arc<Channel>`, method: `String`, and `serde_json::Value`"
55            }
56        }
57    }
58
59    fn unsupported_self_type_error(self) -> &'static str {
60        match self {
61            ServiceKind::Rpc => "Implementing the trait `RPCService` on this type is unsupported",
62            ServiceKind::PubSub => {
63                "Implementing the trait `PubSubRPCService` on this type is unsupported"
64            }
65        }
66    }
67
68    fn dispatch_arm(self, name: &str, ident: &syn::Ident) -> TokenStream2 {
69        match self {
70            ServiceKind::Rpc => quote! {
71                #name => Some(Box::new(
72                    move |params: serde_json::Value| Box::pin(self.#ident(params))
73                )),
74            },
75            ServiceKind::PubSub => quote! {
76                #name => Some(Box::new(
77                    move |chan: std::sync::Arc<karyon_jsonrpc::server::channel::Channel>,
78                          method: String,
79                          params: serde_json::Value| {
80                        Box::pin(self.#ident(chan, method, params))
81                    }
82                )),
83            },
84        }
85    }
86
87    fn impl_block(
88        self,
89        self_ty: &TypePath,
90        service_name: &str,
91        arms: &[TokenStream2],
92    ) -> TokenStream2 {
93        match self {
94            ServiceKind::Rpc => quote! {
95                impl karyon_jsonrpc::server::RPCService for #self_ty {
96                    fn get_method(
97                        &self,
98                        name: &str,
99                    ) -> Option<karyon_jsonrpc::server::RPCMethod> {
100                        match name {
101                            #(#arms)*
102                            _ => None,
103                        }
104                    }
105                    fn name(&self) -> String {
106                        #service_name.to_string()
107                    }
108                }
109            },
110            ServiceKind::PubSub => quote! {
111                impl karyon_jsonrpc::server::PubSubRPCService for #self_ty {
112                    fn get_pubsub_method(
113                        &self,
114                        name: &str,
115                    ) -> Option<karyon_jsonrpc::server::PubSubRPCMethod> {
116                        match name {
117                            #(#arms)*
118                            _ => None,
119                        }
120                    }
121                    fn name(&self) -> String {
122                        #service_name.to_string()
123                    }
124                }
125            },
126        }
127    }
128}
129
130fn expand_service_impl(attr: TokenStream, item: TokenStream, kind: ServiceKind) -> TokenStream {
131    let item2 = item.clone();
132    let parsed_input = parse_macro_input!(item2 as ItemImpl);
133
134    let self_ty = match *parsed_input.self_ty {
135        Type::Path(p) => p,
136        _ => err!(parsed_input.span(), kind.unsupported_self_type_error()),
137    };
138
139    let service_name = match resolve_service_name(attr, &self_ty) {
140        Ok(name) => name,
141        Err(err) => return err.to_compile_error().into(),
142    };
143
144    let methods = match parse_struct_methods(&self_ty, parsed_input.items) {
145        Ok(res) => res,
146        Err(err) => return err.to_compile_error().into(),
147    };
148
149    let mut arms = Vec::with_capacity(methods.len());
150    for (rename, sig) in methods.iter() {
151        if sig.inputs.len() != kind.arity() {
152            err!(sig.span(), kind.arity_error());
153        }
154        let name = rename.clone().unwrap_or(sig.ident.to_string());
155        arms.push(kind.dispatch_arm(&name, &sig.ident));
156    }
157
158    let impl_block = kind.impl_block(&self_ty, &service_name, &arms);
159    let original: TokenStream2 = item.into();
160    quote! {
161        #impl_block
162        #original
163    }
164    .into()
165}
166
167fn resolve_service_name(attr: TokenStream, self_ty: &TypePath) -> Result<String, syn::Error> {
168    if !attr.is_empty() {
169        let parsed_attr: syn::Meta = syn::parse(attr)?;
170        if let Some(name) = parse_service_name(parsed_attr)? {
171            return Ok(name);
172        }
173    }
174    Ok(self_ty.path.require_ident()?.to_string())
175}
176
177fn parse_struct_methods(
178    self_ty: &TypePath,
179    items: Vec<ImplItem>,
180) -> Result<Vec<(Option<String>, Signature)>, syn::Error> {
181    let mut methods: Vec<(Option<String>, Signature)> = vec![];
182
183    if items.is_empty() {
184        return Err(syn::Error::new(
185            self_ty.span(),
186            "At least one method should be implemented",
187        ));
188    }
189
190    for item in items {
191        match item {
192            ImplItem::Fn(method) => {
193                let mut rpc_method_name = None;
194                validate_method(&method.sig)?;
195
196                for attr in method.attrs {
197                    if attr.path().is_ident("rpc_method") {
198                        attr.parse_nested_meta(|meta| {
199                            if meta.path.is_ident("name") {
200                                let value = meta.value()?;
201                                let s: LitStr = value.parse()?;
202                                if s.value().is_empty() {
203                                    return Err(syn::Error::new(attr.span(), "Empty string"));
204                                }
205                                rpc_method_name = Some(s.value().clone());
206                                Ok(())
207                            } else {
208                                Err(syn::Error::new(attr.span(), "Unexpected attribute"))
209                            }
210                        })?;
211                        break;
212                    }
213                }
214
215                methods.push((rpc_method_name, method.sig));
216            }
217            _ => return Err(syn::Error::new(item.span(), "Unexpected item!")),
218        }
219    }
220
221    Ok(methods)
222}
223
224fn validate_method(method: &Signature) -> Result<(), syn::Error> {
225    if let FnArg::Typed(_) = method.inputs[0] {
226        return Err(syn::Error::new(method.span(), "requires `&self` parameter"));
227    }
228
229    if let ReturnType::Default = method.output {
230        return Err(syn::Error::new(
231            method.span(),
232            "requires `Result<serde_json::Value, RPCError>` as return type",
233        ));
234    }
235    Ok(())
236}
237
238fn parse_service_name(attr: syn::Meta) -> Result<Option<String>, syn::Error> {
239    if let syn::Meta::NameValue(ref n) = attr {
240        if n.path.is_ident("name") {
241            if let syn::Expr::Lit(lit) = &n.value {
242                if let syn::Lit::Str(lit_str) = &lit.lit {
243                    if lit_str.value().is_empty() {
244                        return Err(syn::Error::new(attr.span(), "Empty string"));
245                    }
246                    return Ok(Some(lit_str.value().to_string()));
247                }
248            }
249        } else {
250            return Err(syn::Error::new(attr.span(), "Unexpected attribute"));
251        }
252    }
253    Ok(None)
254}