diff options
Diffstat (limited to 'rust/macros')
| -rw-r--r-- | rust/macros/helpers.rs | 86 | ||||
| -rw-r--r-- | rust/macros/pin_data.rs | 168 | ||||
| -rw-r--r-- | rust/macros/quote.rs | 14 | 
3 files changed, 201 insertions, 67 deletions
diff --git a/rust/macros/helpers.rs b/rust/macros/helpers.rs index b2bdd4d8c958..afb0f2e3a36a 100644 --- a/rust/macros/helpers.rs +++ b/rust/macros/helpers.rs @@ -1,6 +1,6 @@  // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Group, TokenTree}; +use proc_macro::{token_stream, Group, Punct, Spacing, TokenStream, TokenTree};  pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> {      if let Some(TokenTree::Ident(ident)) = it.next() { @@ -69,3 +69,87 @@ pub(crate) fn expect_end(it: &mut token_stream::IntoIter) {          panic!("Expected end");      }  } + +pub(crate) struct Generics { +    pub(crate) impl_generics: Vec<TokenTree>, +    pub(crate) ty_generics: Vec<TokenTree>, +} + +/// Parses the given `TokenStream` into `Generics` and the rest. +/// +/// The generics are not present in the rest, but a where clause might remain. +pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) { +    // `impl_generics`, the declared generics with their bounds. +    let mut impl_generics = vec![]; +    // Only the names of the generics, without any bounds. +    let mut ty_generics = vec![]; +    // Tokens not related to the generics e.g. the `where` token and definition. +    let mut rest = vec![]; +    // The current level of `<`. +    let mut nesting = 0; +    let mut toks = input.into_iter(); +    // If we are at the beginning of a generic parameter. +    let mut at_start = true; +    for tt in &mut toks { +        match tt.clone() { +            TokenTree::Punct(p) if p.as_char() == '<' => { +                if nesting >= 1 { +                    // This is inside of the generics and part of some bound. +                    impl_generics.push(tt); +                } +                nesting += 1; +            } +            TokenTree::Punct(p) if p.as_char() == '>' => { +                // This is a parsing error, so we just end it here. +                if nesting == 0 { +                    break; +                } else { +                    nesting -= 1; +                    if nesting >= 1 { +                        // We are still inside of the generics and part of some bound. +                        impl_generics.push(tt); +                    } +                    if nesting == 0 { +                        break; +                    } +                } +            } +            tt => { +                if nesting == 1 { +                    // Here depending on the token, it might be a generic variable name. +                    match &tt { +                        // Ignore const. +                        TokenTree::Ident(i) if i.to_string() == "const" => {} +                        TokenTree::Ident(_) if at_start => { +                            ty_generics.push(tt.clone()); +                            // We also already push the `,` token, this makes it easier to append +                            // generics. +                            ty_generics.push(TokenTree::Punct(Punct::new(',', Spacing::Alone))); +                            at_start = false; +                        } +                        TokenTree::Punct(p) if p.as_char() == ',' => at_start = true, +                        // Lifetimes begin with `'`. +                        TokenTree::Punct(p) if p.as_char() == '\'' && at_start => { +                            ty_generics.push(tt.clone()); +                        } +                        _ => {} +                    } +                } +                if nesting >= 1 { +                    impl_generics.push(tt); +                } else if nesting == 0 { +                    // If we haven't entered the generics yet, we still want to keep these tokens. +                    rest.push(tt); +                } +            } +        } +    } +    rest.extend(toks); +    ( +        Generics { +            impl_generics, +            ty_generics, +        }, +        rest, +    ) +} diff --git a/rust/macros/pin_data.rs b/rust/macros/pin_data.rs index 954149d77181..6d58cfda9872 100644 --- a/rust/macros/pin_data.rs +++ b/rust/macros/pin_data.rs @@ -1,79 +1,127 @@  // SPDX-License-Identifier: Apache-2.0 OR MIT -use proc_macro::{Punct, Spacing, TokenStream, TokenTree}; +use crate::helpers::{parse_generics, Generics}; +use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree};  pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream {      // This proc-macro only does some pre-parsing and then delegates the actual parsing to      // `kernel::__pin_data!`. -    // -    // In here we only collect the generics, since parsing them in declarative macros is very -    // elaborate. We also do not need to analyse their structure, we only need to collect them. -    // `impl_generics`, the declared generics with their bounds. -    let mut impl_generics = vec![]; -    // Only the names of the generics, without any bounds. -    let mut ty_generics = vec![]; -    // Tokens not related to the generics e.g. the `impl` token. -    let mut rest = vec![]; -    // The current level of `<`. -    let mut nesting = 0; -    let mut toks = input.into_iter(); -    // If we are at the beginning of a generic parameter. -    let mut at_start = true; -    for tt in &mut toks { -        match tt.clone() { -            TokenTree::Punct(p) if p.as_char() == '<' => { -                if nesting >= 1 { -                    impl_generics.push(tt); -                } -                nesting += 1; -            } -            TokenTree::Punct(p) if p.as_char() == '>' => { -                if nesting == 0 { -                    break; -                } else { -                    nesting -= 1; -                    if nesting >= 1 { -                        impl_generics.push(tt); -                    } -                    if nesting == 0 { -                        break; -                    } +    let ( +        Generics { +            impl_generics, +            ty_generics, +        }, +        rest, +    ) = parse_generics(input); +    // The struct definition might contain the `Self` type. Since `__pin_data!` will define a new +    // type with the same generics and bounds, this poses a problem, since `Self` will refer to the +    // new type as opposed to this struct definition. Therefore we have to replace `Self` with the +    // concrete name. + +    // Errors that occur when replacing `Self` with `struct_name`. +    let mut errs = TokenStream::new(); +    // The name of the struct with ty_generics. +    let struct_name = rest +        .iter() +        .skip_while(|tt| !matches!(tt, TokenTree::Ident(i) if i.to_string() == "struct")) +        .nth(1) +        .and_then(|tt| match tt { +            TokenTree::Ident(_) => { +                let tt = tt.clone(); +                let mut res = vec![tt]; +                if !ty_generics.is_empty() { +                    // We add this, so it is maximally compatible with e.g. `Self::CONST` which +                    // will be replaced by `StructName::<$generics>::CONST`. +                    res.push(TokenTree::Punct(Punct::new(':', Spacing::Joint))); +                    res.push(TokenTree::Punct(Punct::new(':', Spacing::Alone))); +                    res.push(TokenTree::Punct(Punct::new('<', Spacing::Alone))); +                    res.extend(ty_generics.iter().cloned()); +                    res.push(TokenTree::Punct(Punct::new('>', Spacing::Alone)));                  } +                Some(res)              } -            tt => { -                if nesting == 1 { -                    match &tt { -                        TokenTree::Ident(i) if i.to_string() == "const" => {} -                        TokenTree::Ident(_) if at_start => { -                            ty_generics.push(tt.clone()); -                            ty_generics.push(TokenTree::Punct(Punct::new(',', Spacing::Alone))); -                            at_start = false; -                        } -                        TokenTree::Punct(p) if p.as_char() == ',' => at_start = true, -                        TokenTree::Punct(p) if p.as_char() == '\'' && at_start => { -                            ty_generics.push(tt.clone()); -                        } -                        _ => {} -                    } -                } -                if nesting >= 1 { -                    impl_generics.push(tt); -                } else if nesting == 0 { -                    rest.push(tt); -                } +            _ => None, +        }) +        .unwrap_or_else(|| { +            // If we did not find the name of the struct then we will use `Self` as the replacement +            // and add a compile error to ensure it does not compile. +            errs.extend( +                "::core::compile_error!(\"Could not locate type name.\");" +                    .parse::<TokenStream>() +                    .unwrap(), +            ); +            "Self".parse::<TokenStream>().unwrap().into_iter().collect() +        }); +    let impl_generics = impl_generics +        .into_iter() +        .flat_map(|tt| replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)) +        .collect::<Vec<_>>(); +    let mut rest = rest +        .into_iter() +        .flat_map(|tt| { +            // We ignore top level `struct` tokens, since they would emit a compile error. +            if matches!(&tt, TokenTree::Ident(i) if i.to_string() == "struct") { +                vec![tt] +            } else { +                replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)              } -        } -    } -    rest.extend(toks); +        }) +        .collect::<Vec<_>>();      // This should be the body of the struct `{...}`.      let last = rest.pop(); -    quote!(::kernel::__pin_data! { +    let mut quoted = quote!(::kernel::__pin_data! {          parse_input:          @args(#args),          @sig(#(#rest)*),          @impl_generics(#(#impl_generics)*),          @ty_generics(#(#ty_generics)*),          @body(#last), -    }) +    }); +    quoted.extend(errs); +    quoted +} + +/// Replaces `Self` with `struct_name` and errors on `enum`, `trait`, `struct` `union` and `impl` +/// keywords. +/// +/// The error is appended to `errs` to allow normal parsing to continue. +fn replace_self_and_deny_type_defs( +    struct_name: &Vec<TokenTree>, +    tt: TokenTree, +    errs: &mut TokenStream, +) -> Vec<TokenTree> { +    match tt { +        TokenTree::Ident(ref i) +            if i.to_string() == "enum" +                || i.to_string() == "trait" +                || i.to_string() == "struct" +                || i.to_string() == "union" +                || i.to_string() == "impl" => +        { +            errs.extend( +                format!( +                    "::core::compile_error!(\"Cannot use `{i}` inside of struct definition with \ +                        `#[pin_data]`.\");" +                ) +                .parse::<TokenStream>() +                .unwrap() +                .into_iter() +                .map(|mut tok| { +                    tok.set_span(tt.span()); +                    tok +                }), +            ); +            vec![tt] +        } +        TokenTree::Ident(i) if i.to_string() == "Self" => struct_name.clone(), +        TokenTree::Literal(_) | TokenTree::Punct(_) | TokenTree::Ident(_) => vec![tt], +        TokenTree::Group(g) => vec![TokenTree::Group(Group::new( +            g.delimiter(), +            g.stream() +                .into_iter() +                .flat_map(|tt| replace_self_and_deny_type_defs(struct_name, tt, errs)) +                .collect(), +        ))], +    }  } diff --git a/rust/macros/quote.rs b/rust/macros/quote.rs index c8e08b3c1e4c..dddbb4e6f4cb 100644 --- a/rust/macros/quote.rs +++ b/rust/macros/quote.rs @@ -39,12 +39,14 @@ impl ToTokens for TokenStream {  /// [`quote_spanned!`](https://docs.rs/quote/latest/quote/macro.quote_spanned.html) macro from the  /// `quote` crate but provides only just enough functionality needed by the current `macros` crate.  macro_rules! quote_spanned { -    ($span:expr => $($tt:tt)*) => { -    #[allow(clippy::vec_init_then_push)] -    { -        let mut tokens = ::std::vec::Vec::new(); -        let span = $span; -        quote_spanned!(@proc tokens span $($tt)*); +    ($span:expr => $($tt:tt)*) => {{ +        let mut tokens; +        #[allow(clippy::vec_init_then_push)] +        { +            tokens = ::std::vec::Vec::new(); +            let span = $span; +            quote_spanned!(@proc tokens span $($tt)*); +        }          ::proc_macro::TokenStream::from_iter(tokens)      }};      (@proc $v:ident $span:ident) => {};  | 
