compiler/rustc_builtin_macros/src/autodiff.rs RUST 931 lines View on github.com → Search inside
1//! This module contains the implementation of the `#[autodiff]` attribute.2//! Currently our linter isn't smart enough to see that each import is used in one of the two3//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.4//! FIXME(ZuseZ4): Remove this once we have a smarter linter.56mod llvm_enzyme {7    use std::str::FromStr;8    use std::string::String;910    use rustc_ast::expand::autodiff_attrs::{11        DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity,12    };13    use rustc_ast::token::{Lit, LitKind, Token, TokenKind};14    use rustc_ast::tokenstream::*;15    use rustc_ast::visit::AssocCtxt::*;16    use rustc_ast::{17        self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,18        FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,19        MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,20    };21    use rustc_expand::base::{Annotatable, ExtCtxt};22    use rustc_hir::attrs::RustcAutodiff;23    use rustc_span::{Ident, Span, Symbol, kw, sym};24    use thin_vec::{ThinVec, thin_vec};25    use tracing::{debug, trace};2627    use crate::errors;2829    pub(crate) fn outer_normal_attr(30        kind: &Box<rustc_ast::NormalAttr>,31        id: rustc_ast::AttrId,32        span: Span,33    ) -> rustc_ast::Attribute {34        let style = rustc_ast::AttrStyle::Outer;35        let kind = rustc_ast::AttrKind::Normal(kind.clone());36        rustc_ast::Attribute { kind, id, style, span }37    }3839    // If we have a default `()` return type or explicitley `()` return type,40    // then we often can skip doing some work.41    fn has_ret(ty: &FnRetTy) -> bool {42        match ty {43            FnRetTy::Ty(ty) => !ty.kind.is_unit(),44            FnRetTy::Default(_) => false,45        }46    }47    fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {48        if let Some(l) = x.lit() {49            
match l.kind {50                ast::LitKind::Int(val, _) => {51                    // get an Ident from a lit52                    return rustc_span::Ident::from_str(val.get().to_string().as_str());53                }54                _ => {}55            }56        }5758        let segments = &x.meta_item().unwrap().path.segments;59        assert!(segments.len() == 1);60        segments[0].ident61    }6263    fn name(x: &MetaItemInner) -> String {64        first_ident(x).name.to_string()65    }6667    fn width(x: &MetaItemInner) -> Option<u128> {68        let lit = x.lit()?;69        match lit.kind {70            ast::LitKind::Int(x, _) => Some(x.get()),71            _ => return None,72        }73    }7475    // Get information about the function the macro is applied to76    fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {77        match &iitem.kind {78            ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {79                Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))80            }81            _ => None,82        }83    }8485    pub(crate) fn from_ast(86        ecx: &mut ExtCtxt<'_>,87        meta_item: &ThinVec<MetaItemInner>,88        has_ret: bool,89        mode: DiffMode,90    ) -> RustcAutodiff {91        let dcx = ecx.sess.dcx();9293        // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.94        // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.95        let mut first_activity = 1;9697        let width = if let [_, x, ..] 
= &meta_item[..]98            && let Some(x) = width(x)99        {100            first_activity = 2;101            match x.try_into() {102                Ok(x) => x,103                Err(_) => {104                    dcx.emit_err(errors::AutoDiffInvalidWidth {105                        span: meta_item[1].span(),106                        width: x,107                    });108                    return RustcAutodiff::error();109                }110            }111        } else {112            1113        };114115        let mut activities: Vec<DiffActivity> = vec![];116        let mut errors = false;117        for x in &meta_item[first_activity..] {118            let activity_str = name(&x);119            let res = DiffActivity::from_str(&activity_str);120            match res {121                Ok(x) => activities.push(x),122                Err(_) => {123                    dcx.emit_err(errors::AutoDiffUnknownActivity {124                        span: x.span(),125                        act: activity_str,126                    });127                    errors = true;128                }129            };130        }131        if errors {132            return RustcAutodiff::error();133        }134135        // If a return type exist, we need to split the last activity,136        // otherwise we return None as placeholder.137        let (ret_activity, input_activity) = if has_ret {138            let Some((last, rest)) = activities.split_last() else {139                unreachable!(140                    "should not be reachable because we counted the number of activities previously"141                );142            };143            (last, rest)144        } else {145            (&DiffActivity::None, activities.as_slice())146        };147148        RustcAutodiff {149            mode,150            width,151            ret_activity: *ret_activity,152            input_activity: input_activity.iter().cloned().collect(),153        }154    }155156    fn 
meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {157        let comma: Token = Token::new(TokenKind::Comma, Span::default());158        let val = first_ident(t);159        let t = Token::from_ast_ident(val);160        ts.push(TokenTree::Token(t, Spacing::Joint));161        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));162    }163164    pub(crate) fn expand_forward(165        ecx: &mut ExtCtxt<'_>,166        expand_span: Span,167        meta_item: &ast::MetaItem,168        item: Annotatable,169    ) -> Vec<Annotatable> {170        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)171    }172173    pub(crate) fn expand_reverse(174        ecx: &mut ExtCtxt<'_>,175        expand_span: Span,176        meta_item: &ast::MetaItem,177        item: Annotatable,178    ) -> Vec<Annotatable> {179        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)180    }181182    /// We expand the autodiff macro to generate a new placeholder function which passes183    /// type-checking and can be called by users. The exact signature of the generated function184    /// depends on the configuration provided by the user, but here is an example:185    ///186    /// ```187    /// #[autodiff(cos_box, Reverse, Duplicated, Active)]188    /// fn sin(x: &Box<f32>) -> f32 {189    ///     f32::sin(**x)190    /// }191    /// ```192    /// which becomes expanded to:193    /// ```194    /// #[rustc_autodiff]195    /// fn sin(x: &Box<f32>) -> f32 {196    ///     f32::sin(**x)197    /// }198    /// #[rustc_autodiff(Reverse, Duplicated, Active)]199    /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {200    ///     std::intrinsics::autodiff(sin::<> as fn(..) 
-> .., cos_box::<>, (x, dx, dret))201    /// }202    /// ```203    /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked204    /// in CI.205    pub(crate) fn expand_with_mode(206        ecx: &mut ExtCtxt<'_>,207        expand_span: Span,208        meta_item: &ast::MetaItem,209        mut item: Annotatable,210        mode: DiffMode,211    ) -> Vec<Annotatable> {212        let dcx = ecx.sess.dcx();213214        // first get information about the annotable item: visibility, signature, name and generic215        // parameters.216        // these will be used to generate the differentiated version of the function217        let Some((vis, sig, primal, generics, is_impl)) = (match &item {218            Annotatable::Item(iitem) => {219                extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))220            }221            Annotatable::Stmt(stmt) => match &stmt.kind {222                ast::StmtKind::Item(iitem) => {223                    extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))224                }225                _ => None,226            },227            Annotatable::AssocItem(assoc_item, _ctxt @ (Impl { of_trait: _ } | Trait)) => {228                match &assoc_item.kind {229                    ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. 
}) => Some((230                        assoc_item.vis.clone(),231                        sig.clone(),232                        ident.clone(),233                        generics.clone(),234                        true,235                    )),236                    _ => None,237                }238            }239            _ => None,240        }) else {241            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });242            return vec![item];243        };244245        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {246            ast::MetaItemKind::List(ref vec) => vec.clone(),247            _ => {248                dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });249                return vec![item];250            }251        };252253        let has_ret = has_ret(&sig.decl.output);254255        // create TokenStream from vec elemtents:256        // meta_item doesn't have a .tokens field257        let mut ts: Vec<TokenTree> = vec![];258        if meta_item_vec.len() < 1 {259            // At the bare minimum, we need a fnc name.260            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });261            return vec![item];262        }263264        let mode_symbol = match mode {265            DiffMode::Forward => sym::Forward,266            DiffMode::Reverse => sym::Reverse,267            _ => unreachable!("Unsupported mode: {:?}", mode),268        };269270        // Insert mode token271        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());272        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));273        ts.insert(274            1,275            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),276        );277278        // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.279        // If it is not given, we default to 1 (scalar mode).280        let start_position;281     
   let kind: LitKind = LitKind::Integer;282        let symbol;283        if meta_item_vec.len() >= 2284            && let Some(width) = width(&meta_item_vec[1])285        {286            start_position = 2;287            symbol = Symbol::intern(&width.to_string());288        } else {289            start_position = 1;290            symbol = sym::integer(1);291        }292293        let l: Lit = Lit { kind, symbol, suffix: None };294        let t = Token::new(TokenKind::Literal(l), Span::default());295        let comma = Token::new(TokenKind::Comma, Span::default());296        ts.push(TokenTree::Token(t, Spacing::Joint));297        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));298299        for t in meta_item_vec.clone()[start_position..].iter() {300            meta_item_inner_to_ts(t, &mut ts);301        }302303        if !has_ret {304            // We don't want users to provide a return activity if the function doesn't return anything.305            // For simplicity, we just add a dummy token to the end of the list.306            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());307            ts.push(TokenTree::Token(t, Spacing::Joint));308            ts.push(TokenTree::Token(comma, Spacing::Alone));309        }310        // We remove the last, trailing comma.311        ts.pop();312        let ts: TokenStream = TokenStream::from_iter(ts);313314        let x: RustcAutodiff = from_ast(ecx, &meta_item_vec, has_ret, mode);315        if !x.is_active() {316            // We encountered an error, so we return the original item.317            // This allows us to potentially parse other attributes.318            return vec![item];319        }320        let span = ecx.with_def_site_ctxt(expand_span);321322        let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);323324        let d_body = ecx.block(325            span,326            thin_vec![call_autodiff(327                ecx,328                primal,329                
first_ident(&meta_item_vec[0]),330                span,331                &sig,332                &d_sig,333                &generics,334                is_impl,335            )],336        );337338        // The first element of it is the name of the function to be generated339        let d_fn = Box::new(ast::Fn {340            defaultness: ast::Defaultness::Implicit,341            sig: d_sig,342            ident: first_ident(&meta_item_vec[0]),343            generics,344            contract: None,345            body: Some(d_body),346            define_opaque: None,347            eii_impls: ThinVec::new(),348        });349        let mut rustc_ad_attr =350            Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));351352        let ts2: Vec<TokenTree> = vec![TokenTree::Token(353            Token::new(TokenKind::Ident(sym::never, false.into()), span),354            Spacing::Joint,355        )];356        let never_arg = ast::DelimArgs {357            dspan: DelimSpan::from_single(span),358            delim: ast::token::Delimiter::Parenthesis,359            tokens: TokenStream::from_iter(ts2),360        };361        let inline_item = ast::AttrItem {362            unsafety: ast::Safety::Default,363            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),364            args: rustc_ast::ast::AttrItemKind::Unparsed(ast::AttrArgs::Delimited(never_arg)),365            tokens: None,366        };367        let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });368        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();369        let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);370        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();371        let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);372373        // We're avoid duplicating the attribute `#[rustc_autodiff]`.374        fn same_attribute(attr: &ast::AttrKind, item: 
&ast::AttrKind) -> bool {375            match (attr, item) {376                (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {377                    let a = &a.item.path;378                    let b = &b.item.path;379                    a.segments.iter().eq_by(&b.segments, |a, b| a.ident == b.ident)380                }381                _ => false,382            }383        }384385        let mut has_inline_never = false;386387        // Don't add it multiple times:388        let orig_annotatable: Annotatable = match item {389            Annotatable::Item(ref mut iitem) => {390                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {391                    iitem.attrs.push(attr);392                }393                if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {394                    has_inline_never = true;395                }396                Annotatable::Item(iitem.clone())397            }398            Annotatable::AssocItem(ref mut assoc_item, ctxt @ (Impl { .. 
} | Trait)) => {399                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {400                    assoc_item.attrs.push(attr);401                }402                if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {403                    has_inline_never = true;404                }405                Annotatable::AssocItem(assoc_item.clone(), ctxt)406            }407            Annotatable::Stmt(ref mut stmt) => {408                match stmt.kind {409                    ast::StmtKind::Item(ref mut iitem) => {410                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {411                            iitem.attrs.push(attr);412                        }413                        if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {414                            has_inline_never = true;415                        }416                    }417                    _ => unreachable!("stmt kind checked previously"),418                };419420                Annotatable::Stmt(stmt.clone())421            }422            _ => {423                unreachable!("annotatable kind checked previously")424            }425        };426        // Now update for d_fn427        rustc_ad_attr.item.args = rustc_ast::ast::AttrItemKind::Unparsed(428            rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {429                dspan: DelimSpan::dummy(),430                delim: rustc_ast::token::Delimiter::Parenthesis,431                tokens: ts,432            }),433        );434435        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();436        let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);437438        // If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function439        let mut d_attrs = thin_vec![d_attr];440441        if has_inline_never {442            d_attrs.push(inline_never);443        
}444445        let d_annotatable = match &item {446            Annotatable::AssocItem(_, ctxt) => {447                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);448                let d_fn = Box::new(ast::AssocItem {449                    attrs: d_attrs,450                    id: ast::DUMMY_NODE_ID,451                    span,452                    vis,453                    kind: assoc_item,454                    tokens: None,455                });456                Annotatable::AssocItem(d_fn, *ctxt)457            }458            Annotatable::Item(_) => {459                let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));460                d_fn.vis = vis;461462                Annotatable::Item(d_fn)463            }464            Annotatable::Stmt(_) => {465                let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));466                d_fn.vis = vis;467468                Annotatable::Stmt(Box::new(ast::Stmt {469                    id: ast::DUMMY_NODE_ID,470                    kind: ast::StmtKind::Item(d_fn),471                    span,472                }))473            }474            _ => {475                unreachable!("item kind checked previously")476            }477        };478479        return vec![orig_annotatable, d_annotatable];480    }481482    // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be483    // mutable references or ptrs, because Enzyme will write into them.484    fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {485        let mut ty = ty.clone();486        match ty.kind {487            TyKind::Ptr(ref mut mut_ty) => {488                mut_ty.mutbl = ast::Mutability::Mut;489            }490            TyKind::Ref(_, ref mut mut_ty) => {491                mut_ty.mutbl = ast::Mutability::Mut;492            }493            _ => {494                panic!("unsupported type: {:?}", ty);495            }496        }497        ty498    }499500    // 
Generate `autodiff` intrinsic call501    // ```502    // std::intrinsics::autodiff(source as fn(..) -> .., diff, (args))503    // ```504    fn call_autodiff(505        ecx: &ExtCtxt<'_>,506        primal: Ident,507        diff: Ident,508        span: Span,509        p_sig: &FnSig,510        d_sig: &FnSig,511        generics: &Generics,512        is_impl: bool,513    ) -> rustc_ast::Stmt {514        let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);515516        let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper)));517        let fn_ptr_params: ThinVec<ast::Param> = p_sig518            .decl519            .inputs520            .iter()521            .map(|param| {522                let ty = match &param.ty.kind {523                    TyKind::ImplicitSelf => self_ty(),524                    TyKind::Ref(lt, mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty(525                        span,526                        TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }),527                    ),528                    TyKind::Ptr(mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => {529                        ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }))530                    }531                    _ => param.ty.clone(),532                };533                ast::Param {534                    attrs: ast::AttrVec::new(),535                    ty,536                    pat: Box::new(ecx.pat_wild(span)),537                    id: ast::DUMMY_NODE_ID,538                    span,539                    is_placeholder: false,540                }541            })542            .collect();543        let fn_ptr_ty = ecx.ty(544            span,545            TyKind::FnPtr(Box::new(ast::FnPtrTy {546                safety: p_sig.header.safety,547                ext: p_sig.header.ext,548                generic_params: ThinVec::new(),549                decl: 
Box::new(ast::FnDecl {550                    inputs: fn_ptr_params,551                    output: p_sig.decl.output.clone(),552                }),553                decl_span: span,554            })),555        );556        let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty));557558        let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);559560        let tuple_expr = ecx.expr_tuple(561            span,562            d_sig563                .decl564                .inputs565                .iter()566                .map(|arg| match arg.pat.kind {567                    PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),568                    _ => todo!(),569                })570                .collect::<ThinVec<_>>()571                .into(),572        );573574        let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);575        let enzyme_path = ecx.path(span, enzyme_path_idents);576        let call_expr = ecx.expr_call(577            span,578            ecx.expr_path(enzyme_path),579            vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(),580        );581582        ecx.stmt_expr(call_expr)583    }584585    // Generate turbofish expression from fn name and generics586    // Given `foo` and `<A, B, C>` params, gen `foo::<A, B, C>`587    // We use this expression when passing primal and diff function to the autodiff intrinsic588    fn gen_turbofish_expr(589        ecx: &ExtCtxt<'_>,590        ident: Ident,591        generics: &Generics,592        span: Span,593        is_impl: bool,594    ) -> Box<ast::Expr> {595        let generic_args = generics596            .params597            .iter()598            .filter_map(|p| match &p.kind {599                GenericParamKind::Type { .. 
} => {600                    let path = ast::Path::from_ident(p.ident);601                    let ty = ecx.ty_path(path);602                    Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))603                }604                GenericParamKind::Const { .. } => {605                    let expr = ecx.expr_path(ast::Path::from_ident(p.ident));606                    let anon_const = AnonConst {607                        id: ast::DUMMY_NODE_ID,608                        value: expr,609                        mgca_disambiguation: MgcaDisambiguation::Direct,610                    };611                    Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))612                }613                GenericParamKind::Lifetime { .. } => None,614            })615            .collect::<ThinVec<_>>();616617        let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };618619        let segment = PathSegment {620            ident,621            id: ast::DUMMY_NODE_ID,622            args: Some(Box::new(GenericArgs::AngleBracketed(args))),623        };624625        let segments = if is_impl {626            thin_vec![627                PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },628                segment,629            ]630        } else {631            thin_vec![segment]632        };633634        let path = Path { span, segments, tokens: None };635636        ecx.expr_path(path)637    }638639    // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must640    // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.641    // Active arguments must be scalars. 
Their shadow argument is added to the return type (and will be642    // zero-initialized by Enzyme).643    // Each argument of the primal function (and the return type if existing) must be annotated with an644    // activity.645    //646    // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or647    // both), we emit an error and return the original signature. This allows us to continue parsing.648    // FIXME(Sa4dUs): make individual activities' span available so errors649    // can point to only the activity instead of the entire attribute650    fn gen_enzyme_decl(651        ecx: &ExtCtxt<'_>,652        sig: &ast::FnSig,653        x: &RustcAutodiff,654        span: Span,655    ) -> ast::FnSig {656        let dcx = ecx.sess.dcx();657        let has_ret = has_ret(&sig.decl.output);658        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };659        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };660        if sig_args != num_activities {661            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {662                span,663                expected: sig_args,664                found: num_activities,665            });666            // This is not the right signature, but we can continue parsing.667            return sig.clone();668        }669        assert!(sig.decl.inputs.len() == x.input_activity.len());670        assert!(has_ret == x.has_ret_activity());671        let mut d_decl = sig.decl.clone();672        let mut d_inputs = Vec::new();673        let mut new_inputs = Vec::new();674        let mut idents = Vec::new();675        let mut act_ret = ThinVec::new();676677        // We have two loops, a first one just to check the activities and types and possibly report678        // multiple errors in one compilation session.679        let mut errors = false;680        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {681            if 
!valid_input_activity(x.mode, *activity) {682                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {683                    span,684                    mode: x.mode.to_string(),685                    act: activity.to_string(),686                });687                errors = true;688            }689            if !valid_ty_for_activity(&arg.ty, *activity) {690                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {691                    span: arg.ty.span,692                    act: activity.to_string(),693                });694                errors = true;695            }696        }697698        if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {699            dcx.emit_err(errors::AutoDiffInvalidRetAct {700                span,701                mode: x.mode.to_string(),702                act: x.ret_activity.to_string(),703            });704            // We don't set `errors = true` to avoid annoying type errors relative705            // to the expanded macro type signature706        }707708        if errors {709            // This is not the right signature, but we can continue parsing.710            return sig.clone();711        }712713        let unsafe_activities = x714            .input_activity715            .iter()716            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));717        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {718            d_inputs.push(arg.clone());719            match activity {720                DiffActivity::Active => {721                    act_ret.push(arg.ty.clone());722                    // if width =/= 1, then push [arg.ty; width] to act_ret723                }724                DiffActivity::ActiveOnly => {725                    // We will add the active scalar to the return type.726                    // This is handled later.727                }728                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly 
=> {729                    for i in 0..x.width {730                        let mut shadow_arg = arg.clone();731                        // We += into the shadow in reverse mode.732                        shadow_arg.ty = Box::new(assure_mut_ref(&arg.ty));733                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {734                            ident.name735                        } else {736                            debug!("{:#?}", &shadow_arg.pat);737                            panic!("not an ident?");738                        };739                        let name: String = format!("d{}_{}", old_name, i);740                        new_inputs.push(name.clone());741                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);742                        shadow_arg.pat = Box::new(ast::Pat {743                            id: ast::DUMMY_NODE_ID,744                            kind: PatKind::Ident(BindingMode::NONE, ident, None),745                            span: shadow_arg.pat.span,746                            tokens: shadow_arg.pat.tokens.clone(),747                        });748                        d_inputs.push(shadow_arg.clone());749                    }750                }751                DiffActivity::Dual752                | DiffActivity::DualOnly753                | DiffActivity::Dualv754                | DiffActivity::DualvOnly => {755                    // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause756                    // Enzyme to not expect N arguments, but one argument (which is instead larger).757                    let iterations =758                        if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {759                            1760                        } else {761                            x.width762                        };763                    for i in 0..iterations {764                        let mut shadow_arg = 
arg.clone();765                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {766                            ident.name767                        } else {768                            debug!("{:#?}", &shadow_arg.pat);769                            panic!("not an ident?");770                        };771                        let name: String = format!("b{}_{}", old_name, i);772                        new_inputs.push(name.clone());773                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);774                        shadow_arg.pat = Box::new(ast::Pat {775                            id: ast::DUMMY_NODE_ID,776                            kind: PatKind::Ident(BindingMode::NONE, ident, None),777                            span: shadow_arg.pat.span,778                            tokens: shadow_arg.pat.tokens.clone(),779                        });780                        d_inputs.push(shadow_arg.clone());781                    }782                }783                DiffActivity::Const => {784                    // Nothing to do here.785                }786                DiffActivity::None | DiffActivity::FakeActivitySize(_) => {787                    panic!("Should not happen");788                }789            }790            if let PatKind::Ident(_, ident, _) = arg.pat.kind {791                idents.push(ident.clone());792            } else {793                panic!("not an ident?");794            }795        }796797        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;798        if active_only_ret {799            assert!(x.mode.is_rev());800        }801802        // If we return a scalar in the primal and the scalar is active,803        // then add it as last arg to the inputs.804        if x.mode.is_rev() {805            match x.ret_activity {806                DiffActivity::Active | DiffActivity::ActiveOnly => {807                    let ty = match d_decl.output {808                 
       FnRetTy::Ty(ref ty) => ty.clone(),809                        FnRetTy::Default(span) => {810                            panic!("Did not expect Default ret ty: {:?}", span);811                        }812                    };813                    let name = "dret".to_string();814                    let ident = Ident::from_str_and_span(&name, ty.span);815                    let shadow_arg = ast::Param {816                        attrs: ThinVec::new(),817                        ty: ty.clone(),818                        pat: Box::new(ast::Pat {819                            id: ast::DUMMY_NODE_ID,820                            kind: PatKind::Ident(BindingMode::NONE, ident, None),821                            span: ty.span,822                            tokens: None,823                        }),824                        id: ast::DUMMY_NODE_ID,825                        span: ty.span,826                        is_placeholder: false,827                    };828                    d_inputs.push(shadow_arg);829                    new_inputs.push(name);830                }831                _ => {}832            }833        }834        d_decl.inputs = d_inputs.into();835836        if x.mode.is_fwd() {837            let ty = match d_decl.output {838                FnRetTy::Ty(ref ty) => ty.clone(),839                FnRetTy::Default(span) => {840                    // We want to return std::hint::black_box(()).841                    let kind = TyKind::Tup(ThinVec::new());842                    let ty = Box::new(rustc_ast::Ty {843                        kind,844                        id: ast::DUMMY_NODE_ID,845                        span,846                        tokens: None,847                    });848                    d_decl.output = FnRetTy::Ty(ty.clone());849                    assert!(matches!(x.ret_activity, DiffActivity::None));850                    // this won't be used below, so any type would be fine.851                    ty852                }853  
          };854855            if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {856                let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {857                    // Dual can only be used for f32/f64 ret.858                    // In that case we return now a tuple with two floats.859                    TyKind::Tup(thin_vec![ty.clone(), ty.clone()])860                } else {861                    // We have to return [T; width+1], +1 for the primal return.862                    let anon_const = rustc_ast::AnonConst {863                        id: ast::DUMMY_NODE_ID,864                        value: ecx.expr_usize(span, 1 + x.width as usize),865                        mgca_disambiguation: MgcaDisambiguation::Direct,866                    };867                    TyKind::Array(ty.clone(), anon_const)868                };869                let ty = Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });870                d_decl.output = FnRetTy::Ty(ty);871            }872            if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {873                // No need to change the return type,874                // we will just return the shadow in place of the primal return.875                // However, if we have a width > 1, then we don't return -> T, but -> [T; width]876                if x.width > 1 {877                    let anon_const = rustc_ast::AnonConst {878                        id: ast::DUMMY_NODE_ID,879                        value: ecx.expr_usize(span, x.width as usize),880                        mgca_disambiguation: MgcaDisambiguation::Direct,881                    };882                    let kind = TyKind::Array(ty.clone(), anon_const);883                    let ty =884                        Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });885                    d_decl.output = FnRetTy::Ty(ty);886                }887            
}888        }889890        // If we use ActiveOnly, drop the original return value.891        d_decl.output =892            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };893894        trace!("act_ret: {:?}", act_ret);895896        // If we have an active input scalar, add it's gradient to the897        // return type. This might require changing the return type to a898        // tuple.899        if act_ret.len() > 0 {900            let ret_ty = match d_decl.output {901                FnRetTy::Ty(ref ty) => {902                    if !active_only_ret {903                        act_ret.insert(0, ty.clone());904                    }905                    let kind = TyKind::Tup(act_ret);906                    Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })907                }908                FnRetTy::Default(span) => {909                    if act_ret.len() == 1 {910                        act_ret[0].clone()911                    } else {912                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());913                        Box::new(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })914                    }915                }916            };917            d_decl.output = FnRetTy::Ty(ret_ty);918        }919920        let mut d_header = sig.header.clone();921        if unsafe_activities {922            d_header.safety = rustc_ast::Safety::Unsafe(span);923        }924        let d_sig = FnSig { header: d_header, decl: d_decl, span };925        trace!("Generated signature: {:?}", d_sig);926        d_sig927    }928}929930pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};

Code quality findings (46)

Warning: '.unwrap()' will panic on None/Err variants. Prefer using pattern matching (match, if let), combinators (map, and_then), or the '?' operator for robust error handling.
warning correctness unwrap-usage
let segments = &x.meta_item().unwrap().path.segments;
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
segments[0].ident
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
let width = if let [_, x, ..] = &meta_item[..]
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
span: meta_item[1].span(),
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
for x in &meta_item[first_activity..] {
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
&& let Some(width) = width(&meta_item_vec[1])
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
for t in meta_item_vec.clone()[start_position..].iter() {
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
first_ident(&meta_item_vec[0]),
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
ident: first_ident(&meta_item_vec[0]),
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
// if width =/= 1, then push [arg.ty; width] to act_ret
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
// We have to return [T; width+1], +1 for the primal return.
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
act_ret[0].clone()
Info: Wildcard imports (`use some::path::*;`) can obscure the origin of names and lead to conflicts. Prefer importing specific items explicitly.
info maintainability wildcard-import
use rustc_ast::tokenstream::*;
Info: Wildcard imports (`use some::path::*;`) can obscure the origin of names and lead to conflicts. Prefer importing specific items explicitly.
info maintainability wildcard-import
use rustc_ast::visit::AssocCtxt::*;
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
match lit.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
match &iitem.kind {
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
Ok(x) => activities.push(x),
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
let Some((vis, sig, primal, generics, is_impl)) = (match &item {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
Annotatable::Stmt(stmt) => match &stmt.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
match &assoc_item.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
let mode_symbol = match mode {
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
ts.push(TokenTree::Token(t, Spacing::Joint));
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
for t in meta_item_vec.clone()[start_position..].iter() {
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
ts.push(TokenTree::Token(t, Spacing::Joint));
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
ts.push(TokenTree::Token(comma, Spacing::Alone));
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
match (attr, item) {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
match stmt.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
match ty.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
let ty = match &param.ty.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
.map(|arg| match arg.pat.kind {
Maintainability Info: `todo!()` or `unimplemented!()` macros indicate incomplete code paths that will panic at runtime if reached. Ensure these are replaced with actual logic before production use.
info correctness todo-unimplemented
_ => todo!(),
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
let mut d_decl = sig.decl.clone();
Performance Info: Calling .to_string() (especially on &str) allocates a new String. If done repeatedly in loops, consider alternatives like working with &str or using crates like `itoa`/`ryu` for number-to-string conversion.
info performance to-string-in-loop
mode: x.mode.to_string(),
Performance Info: Calling .to_string() (especially on &str) allocates a new String. If done repeatedly in loops, consider alternatives like working with &str or using crates like `itoa`/`ryu` for number-to-string conversion.
info performance to-string-in-loop
act: activity.to_string(),
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
return sig.clone();
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
d_inputs.push(arg.clone());
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
d_inputs.push(arg.clone());
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
act_ret.push(arg.ty.clone());
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
act_ret.push(arg.ty.clone());
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
let mut shadow_arg = arg.clone();
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
let mut shadow_arg = arg.clone();
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info performance clone-in-loop
new_inputs.push(name.clone());
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info performance push-without-reserve
new_inputs.push(name.clone());

Get this view in your editor

Same data, no extra tab — call code_get_file + code_get_findings over MCP from Claude/Cursor/Copilot.