1//! This module contains the implementation of the `#[autodiff]` attribute.2//! Currently our linter isn't smart enough to see that each import is used in one of the two3//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.4//! FIXME(ZuseZ4): Remove this once we have a smarter linter.56mod llvm_enzyme {7 use std::str::FromStr;8 use std::string::String;910 use rustc_ast::expand::autodiff_attrs::{11 DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity,12 };13 use rustc_ast::token::{Lit, LitKind, Token, TokenKind};14 use rustc_ast::tokenstream::*;15 use rustc_ast::visit::AssocCtxt::*;16 use rustc_ast::{17 self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,18 FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,19 MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,20 };21 use rustc_expand::base::{Annotatable, ExtCtxt};22 use rustc_hir::attrs::RustcAutodiff;23 use rustc_span::{Ident, Span, Symbol, kw, sym};24 use thin_vec::{ThinVec, thin_vec};25 use tracing::{debug, trace};2627 use crate::errors;2829 pub(crate) fn outer_normal_attr(30 kind: &Box<rustc_ast::NormalAttr>,31 id: rustc_ast::AttrId,32 span: Span,33 ) -> rustc_ast::Attribute {34 let style = rustc_ast::AttrStyle::Outer;35 let kind = rustc_ast::AttrKind::Normal(kind.clone());36 rustc_ast::Attribute { kind, id, style, span }37 }3839 // If we have a default `()` return type or explicitley `()` return type,40 // then we often can skip doing some work.41 fn has_ret(ty: &FnRetTy) -> bool {42 match ty {43 FnRetTy::Ty(ty) => !ty.kind.is_unit(),44 FnRetTy::Default(_) => false,45 }46 }47 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {48 if let Some(l) = x.lit() {49 match l.kind {50 ast::LitKind::Int(val, _) => {51 // get an Ident from a lit52 return rustc_span::Ident::from_str(val.get().to_string().as_str());53 }54 _ => {}55 }56 }5758 let segments = 
&x.meta_item().unwrap().path.segments;59 assert!(segments.len() == 1);60 segments[0].ident61 }6263 fn name(x: &MetaItemInner) -> String {64 first_ident(x).name.to_string()65 }6667 fn width(x: &MetaItemInner) -> Option<u128> {68 let lit = x.lit()?;69 match lit.kind {70 ast::LitKind::Int(x, _) => Some(x.get()),71 _ => return None,72 }73 }7475 // Get information about the function the macro is applied to76 fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {77 match &iitem.kind {78 ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {79 Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))80 }81 _ => None,82 }83 }8485 pub(crate) fn from_ast(86 ecx: &mut ExtCtxt<'_>,87 meta_item: &ThinVec<MetaItemInner>,88 has_ret: bool,89 mode: DiffMode,90 ) -> RustcAutodiff {91 let dcx = ecx.sess.dcx();9293 // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.94 // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.95 let mut first_activity = 1;9697 let width = if let [_, x, ..] = &meta_item[..]98 && let Some(x) = width(x)99 {100 first_activity = 2;101 match x.try_into() {102 Ok(x) => x,103 Err(_) => {104 dcx.emit_err(errors::AutoDiffInvalidWidth {105 span: meta_item[1].span(),106 width: x,107 });108 return RustcAutodiff::error();109 }110 }111 } else {112 1113 };114115 let mut activities: Vec<DiffActivity> = vec![];116 let mut errors = false;117 for x in &meta_item[first_activity..] 
{118 let activity_str = name(&x);119 let res = DiffActivity::from_str(&activity_str);120 match res {121 Ok(x) => activities.push(x),122 Err(_) => {123 dcx.emit_err(errors::AutoDiffUnknownActivity {124 span: x.span(),125 act: activity_str,126 });127 errors = true;128 }129 };130 }131 if errors {132 return RustcAutodiff::error();133 }134135 // If a return type exist, we need to split the last activity,136 // otherwise we return None as placeholder.137 let (ret_activity, input_activity) = if has_ret {138 let Some((last, rest)) = activities.split_last() else {139 unreachable!(140 "should not be reachable because we counted the number of activities previously"141 );142 };143 (last, rest)144 } else {145 (&DiffActivity::None, activities.as_slice())146 };147148 RustcAutodiff {149 mode,150 width,151 ret_activity: *ret_activity,152 input_activity: input_activity.iter().cloned().collect(),153 }154 }155156 fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {157 let comma: Token = Token::new(TokenKind::Comma, Span::default());158 let val = first_ident(t);159 let t = Token::from_ast_ident(val);160 ts.push(TokenTree::Token(t, Spacing::Joint));161 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));162 }163164 pub(crate) fn expand_forward(165 ecx: &mut ExtCtxt<'_>,166 expand_span: Span,167 meta_item: &ast::MetaItem,168 item: Annotatable,169 ) -> Vec<Annotatable> {170 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)171 }172173 pub(crate) fn expand_reverse(174 ecx: &mut ExtCtxt<'_>,175 expand_span: Span,176 meta_item: &ast::MetaItem,177 item: Annotatable,178 ) -> Vec<Annotatable> {179 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)180 }181182 /// We expand the autodiff macro to generate a new placeholder function which passes183 /// type-checking and can be called by users. 
The exact signature of the generated function184 /// depends on the configuration provided by the user, but here is an example:185 ///186 /// ```187 /// #[autodiff(cos_box, Reverse, Duplicated, Active)]188 /// fn sin(x: &Box<f32>) -> f32 {189 /// f32::sin(**x)190 /// }191 /// ```192 /// which becomes expanded to:193 /// ```194 /// #[rustc_autodiff]195 /// fn sin(x: &Box<f32>) -> f32 {196 /// f32::sin(**x)197 /// }198 /// #[rustc_autodiff(Reverse, Duplicated, Active)]199 /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {200 /// std::intrinsics::autodiff(sin::<> as fn(..) -> .., cos_box::<>, (x, dx, dret))201 /// }202 /// ```203 /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked204 /// in CI.205 pub(crate) fn expand_with_mode(206 ecx: &mut ExtCtxt<'_>,207 expand_span: Span,208 meta_item: &ast::MetaItem,209 mut item: Annotatable,210 mode: DiffMode,211 ) -> Vec<Annotatable> {212 let dcx = ecx.sess.dcx();213214 // first get information about the annotable item: visibility, signature, name and generic215 // parameters.216 // these will be used to generate the differentiated version of the function217 let Some((vis, sig, primal, generics, is_impl)) = (match &item {218 Annotatable::Item(iitem) => {219 extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))220 }221 Annotatable::Stmt(stmt) => match &stmt.kind {222 ast::StmtKind::Item(iitem) => {223 extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))224 }225 _ => None,226 },227 Annotatable::AssocItem(assoc_item, _ctxt @ (Impl { of_trait: _ } | Trait)) => {228 match &assoc_item.kind {229 ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. 
}) => Some((230 assoc_item.vis.clone(),231 sig.clone(),232 ident.clone(),233 generics.clone(),234 true,235 )),236 _ => None,237 }238 }239 _ => None,240 }) else {241 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });242 return vec![item];243 };244245 let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {246 ast::MetaItemKind::List(ref vec) => vec.clone(),247 _ => {248 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });249 return vec![item];250 }251 };252253 let has_ret = has_ret(&sig.decl.output);254255 // create TokenStream from vec elemtents:256 // meta_item doesn't have a .tokens field257 let mut ts: Vec<TokenTree> = vec![];258 if meta_item_vec.len() < 1 {259 // At the bare minimum, we need a fnc name.260 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });261 return vec![item];262 }263264 let mode_symbol = match mode {265 DiffMode::Forward => sym::Forward,266 DiffMode::Reverse => sym::Reverse,267 _ => unreachable!("Unsupported mode: {:?}", mode),268 };269270 // Insert mode token271 let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());272 ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));273 ts.insert(274 1,275 TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),276 );277278 // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.279 // If it is not given, we default to 1 (scalar mode).280 let start_position;281 let kind: LitKind = LitKind::Integer;282 let symbol;283 if meta_item_vec.len() >= 2284 && let Some(width) = width(&meta_item_vec[1])285 {286 start_position = 2;287 symbol = Symbol::intern(&width.to_string());288 } else {289 start_position = 1;290 symbol = sym::integer(1);291 }292293 let l: Lit = Lit { kind, symbol, suffix: None };294 let t = Token::new(TokenKind::Literal(l), Span::default());295 let comma = Token::new(TokenKind::Comma, Span::default());296 ts.push(TokenTree::Token(t, 
Spacing::Joint));297 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));298299 for t in meta_item_vec.clone()[start_position..].iter() {300 meta_item_inner_to_ts(t, &mut ts);301 }302303 if !has_ret {304 // We don't want users to provide a return activity if the function doesn't return anything.305 // For simplicity, we just add a dummy token to the end of the list.306 let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());307 ts.push(TokenTree::Token(t, Spacing::Joint));308 ts.push(TokenTree::Token(comma, Spacing::Alone));309 }310 // We remove the last, trailing comma.311 ts.pop();312 let ts: TokenStream = TokenStream::from_iter(ts);313314 let x: RustcAutodiff = from_ast(ecx, &meta_item_vec, has_ret, mode);315 if !x.is_active() {316 // We encountered an error, so we return the original item.317 // This allows us to potentially parse other attributes.318 return vec![item];319 }320 let span = ecx.with_def_site_ctxt(expand_span);321322 let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);323324 let d_body = ecx.block(325 span,326 thin_vec![call_autodiff(327 ecx,328 primal,329 first_ident(&meta_item_vec[0]),330 span,331 &sig,332 &d_sig,333 &generics,334 is_impl,335 )],336 );337338 // The first element of it is the name of the function to be generated339 let d_fn = Box::new(ast::Fn {340 defaultness: ast::Defaultness::Implicit,341 sig: d_sig,342 ident: first_ident(&meta_item_vec[0]),343 generics,344 contract: None,345 body: Some(d_body),346 define_opaque: None,347 eii_impls: ThinVec::new(),348 });349 let mut rustc_ad_attr =350 Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));351352 let ts2: Vec<TokenTree> = vec![TokenTree::Token(353 Token::new(TokenKind::Ident(sym::never, false.into()), span),354 Spacing::Joint,355 )];356 let never_arg = ast::DelimArgs {357 dspan: DelimSpan::from_single(span),358 delim: ast::token::Delimiter::Parenthesis,359 tokens: TokenStream::from_iter(ts2),360 };361 let inline_item = 
ast::AttrItem {362 unsafety: ast::Safety::Default,363 path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),364 args: rustc_ast::ast::AttrItemKind::Unparsed(ast::AttrArgs::Delimited(never_arg)),365 tokens: None,366 };367 let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });368 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();369 let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);370 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();371 let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);372373 // We're avoid duplicating the attribute `#[rustc_autodiff]`.374 fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {375 match (attr, item) {376 (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {377 let a = &a.item.path;378 let b = &b.item.path;379 a.segments.iter().eq_by(&b.segments, |a, b| a.ident == b.ident)380 }381 _ => false,382 }383 }384385 let mut has_inline_never = false;386387 // Don't add it multiple times:388 let orig_annotatable: Annotatable = match item {389 Annotatable::Item(ref mut iitem) => {390 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {391 iitem.attrs.push(attr);392 }393 if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {394 has_inline_never = true;395 }396 Annotatable::Item(iitem.clone())397 }398 Annotatable::AssocItem(ref mut assoc_item, ctxt @ (Impl { .. 
} | Trait)) => {399 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {400 assoc_item.attrs.push(attr);401 }402 if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {403 has_inline_never = true;404 }405 Annotatable::AssocItem(assoc_item.clone(), ctxt)406 }407 Annotatable::Stmt(ref mut stmt) => {408 match stmt.kind {409 ast::StmtKind::Item(ref mut iitem) => {410 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {411 iitem.attrs.push(attr);412 }413 if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {414 has_inline_never = true;415 }416 }417 _ => unreachable!("stmt kind checked previously"),418 };419420 Annotatable::Stmt(stmt.clone())421 }422 _ => {423 unreachable!("annotatable kind checked previously")424 }425 };426 // Now update for d_fn427 rustc_ad_attr.item.args = rustc_ast::ast::AttrItemKind::Unparsed(428 rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {429 dspan: DelimSpan::dummy(),430 delim: rustc_ast::token::Delimiter::Parenthesis,431 tokens: ts,432 }),433 );434435 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();436 let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);437438 // If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function439 let mut d_attrs = thin_vec![d_attr];440441 if has_inline_never {442 d_attrs.push(inline_never);443 }444445 let d_annotatable = match &item {446 Annotatable::AssocItem(_, ctxt) => {447 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);448 let d_fn = Box::new(ast::AssocItem {449 attrs: d_attrs,450 id: ast::DUMMY_NODE_ID,451 span,452 vis,453 kind: assoc_item,454 tokens: None,455 });456 Annotatable::AssocItem(d_fn, *ctxt)457 }458 Annotatable::Item(_) => {459 let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));460 d_fn.vis = vis;461462 Annotatable::Item(d_fn)463 }464 Annotatable::Stmt(_) => {465 let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));466 
d_fn.vis = vis;467468 Annotatable::Stmt(Box::new(ast::Stmt {469 id: ast::DUMMY_NODE_ID,470 kind: ast::StmtKind::Item(d_fn),471 span,472 }))473 }474 _ => {475 unreachable!("item kind checked previously")476 }477 };478479 return vec![orig_annotatable, d_annotatable];480 }481482 // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be483 // mutable references or ptrs, because Enzyme will write into them.484 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {485 let mut ty = ty.clone();486 match ty.kind {487 TyKind::Ptr(ref mut mut_ty) => {488 mut_ty.mutbl = ast::Mutability::Mut;489 }490 TyKind::Ref(_, ref mut mut_ty) => {491 mut_ty.mutbl = ast::Mutability::Mut;492 }493 _ => {494 panic!("unsupported type: {:?}", ty);495 }496 }497 ty498 }499500 // Generate `autodiff` intrinsic call501 // ```502 // std::intrinsics::autodiff(source as fn(..) -> .., diff, (args))503 // ```504 fn call_autodiff(505 ecx: &ExtCtxt<'_>,506 primal: Ident,507 diff: Ident,508 span: Span,509 p_sig: &FnSig,510 d_sig: &FnSig,511 generics: &Generics,512 is_impl: bool,513 ) -> rustc_ast::Stmt {514 let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);515516 let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper)));517 let fn_ptr_params: ThinVec<ast::Param> = p_sig518 .decl519 .inputs520 .iter()521 .map(|param| {522 let ty = match ¶m.ty.kind {523 TyKind::ImplicitSelf => self_ty(),524 TyKind::Ref(lt, mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty(525 span,526 TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }),527 ),528 TyKind::Ptr(mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => {529 ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }))530 }531 _ => param.ty.clone(),532 };533 ast::Param {534 attrs: ast::AttrVec::new(),535 ty,536 pat: Box::new(ecx.pat_wild(span)),537 id: ast::DUMMY_NODE_ID,538 span,539 is_placeholder: false,540 }541 })542 
.collect();543 let fn_ptr_ty = ecx.ty(544 span,545 TyKind::FnPtr(Box::new(ast::FnPtrTy {546 safety: p_sig.header.safety,547 ext: p_sig.header.ext,548 generic_params: ThinVec::new(),549 decl: Box::new(ast::FnDecl {550 inputs: fn_ptr_params,551 output: p_sig.decl.output.clone(),552 }),553 decl_span: span,554 })),555 );556 let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty));557558 let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);559560 let tuple_expr = ecx.expr_tuple(561 span,562 d_sig563 .decl564 .inputs565 .iter()566 .map(|arg| match arg.pat.kind {567 PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),568 _ => todo!(),569 })570 .collect::<ThinVec<_>>()571 .into(),572 );573574 let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);575 let enzyme_path = ecx.path(span, enzyme_path_idents);576 let call_expr = ecx.expr_call(577 span,578 ecx.expr_path(enzyme_path),579 vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(),580 );581582 ecx.stmt_expr(call_expr)583 }584585 // Generate turbofish expression from fn name and generics586 // Given `foo` and `<A, B, C>` params, gen `foo::<A, B, C>`587 // We use this expression when passing primal and diff function to the autodiff intrinsic588 fn gen_turbofish_expr(589 ecx: &ExtCtxt<'_>,590 ident: Ident,591 generics: &Generics,592 span: Span,593 is_impl: bool,594 ) -> Box<ast::Expr> {595 let generic_args = generics596 .params597 .iter()598 .filter_map(|p| match &p.kind {599 GenericParamKind::Type { .. } => {600 let path = ast::Path::from_ident(p.ident);601 let ty = ecx.ty_path(path);602 Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))603 }604 GenericParamKind::Const { .. 
} => {605 let expr = ecx.expr_path(ast::Path::from_ident(p.ident));606 let anon_const = AnonConst {607 id: ast::DUMMY_NODE_ID,608 value: expr,609 mgca_disambiguation: MgcaDisambiguation::Direct,610 };611 Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))612 }613 GenericParamKind::Lifetime { .. } => None,614 })615 .collect::<ThinVec<_>>();616617 let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };618619 let segment = PathSegment {620 ident,621 id: ast::DUMMY_NODE_ID,622 args: Some(Box::new(GenericArgs::AngleBracketed(args))),623 };624625 let segments = if is_impl {626 thin_vec![627 PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },628 segment,629 ]630 } else {631 thin_vec![segment]632 };633634 let path = Path { span, segments, tokens: None };635636 ecx.expr_path(path)637 }638639 // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must640 // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.641 // Active arguments must be scalars. Their shadow argument is added to the return type (and will be642 // zero-initialized by Enzyme).643 // Each argument of the primal function (and the return type if existing) must be annotated with an644 // activity.645 //646 // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or647 // both), we emit an error and return the original signature. 
This allows us to continue parsing.648 // FIXME(Sa4dUs): make individual activities' span available so errors649 // can point to only the activity instead of the entire attribute650 fn gen_enzyme_decl(651 ecx: &ExtCtxt<'_>,652 sig: &ast::FnSig,653 x: &RustcAutodiff,654 span: Span,655 ) -> ast::FnSig {656 let dcx = ecx.sess.dcx();657 let has_ret = has_ret(&sig.decl.output);658 let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };659 let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };660 if sig_args != num_activities {661 dcx.emit_err(errors::AutoDiffInvalidNumberActivities {662 span,663 expected: sig_args,664 found: num_activities,665 });666 // This is not the right signature, but we can continue parsing.667 return sig.clone();668 }669 assert!(sig.decl.inputs.len() == x.input_activity.len());670 assert!(has_ret == x.has_ret_activity());671 let mut d_decl = sig.decl.clone();672 let mut d_inputs = Vec::new();673 let mut new_inputs = Vec::new();674 let mut idents = Vec::new();675 let mut act_ret = ThinVec::new();676677 // We have two loops, a first one just to check the activities and types and possibly report678 // multiple errors in one compilation session.679 let mut errors = false;680 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {681 if !valid_input_activity(x.mode, *activity) {682 dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {683 span,684 mode: x.mode.to_string(),685 act: activity.to_string(),686 });687 errors = true;688 }689 if !valid_ty_for_activity(&arg.ty, *activity) {690 dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {691 span: arg.ty.span,692 act: activity.to_string(),693 });694 errors = true;695 }696 }697698 if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {699 dcx.emit_err(errors::AutoDiffInvalidRetAct {700 span,701 mode: x.mode.to_string(),702 act: x.ret_activity.to_string(),703 });704 // We don't set `errors = true` to avoid annoying type errors 
relative705 // to the expanded macro type signature706 }707708 if errors {709 // This is not the right signature, but we can continue parsing.710 return sig.clone();711 }712713 let unsafe_activities = x714 .input_activity715 .iter()716 .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));717 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {718 d_inputs.push(arg.clone());719 match activity {720 DiffActivity::Active => {721 act_ret.push(arg.ty.clone());722 // if width =/= 1, then push [arg.ty; width] to act_ret723 }724 DiffActivity::ActiveOnly => {725 // We will add the active scalar to the return type.726 // This is handled later.727 }728 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {729 for i in 0..x.width {730 let mut shadow_arg = arg.clone();731 // We += into the shadow in reverse mode.732 shadow_arg.ty = Box::new(assure_mut_ref(&arg.ty));733 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {734 ident.name735 } else {736 debug!("{:#?}", &shadow_arg.pat);737 panic!("not an ident?");738 };739 let name: String = format!("d{}_{}", old_name, i);740 new_inputs.push(name.clone());741 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);742 shadow_arg.pat = Box::new(ast::Pat {743 id: ast::DUMMY_NODE_ID,744 kind: PatKind::Ident(BindingMode::NONE, ident, None),745 span: shadow_arg.pat.span,746 tokens: shadow_arg.pat.tokens.clone(),747 });748 d_inputs.push(shadow_arg.clone());749 }750 }751 DiffActivity::Dual752 | DiffActivity::DualOnly753 | DiffActivity::Dualv754 | DiffActivity::DualvOnly => {755 // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause756 // Enzyme to not expect N arguments, but one argument (which is instead larger).757 let iterations =758 if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {759 1760 } else {761 x.width762 };763 for i in 0..iterations {764 let mut shadow_arg = arg.clone();765 let old_name = if let 
PatKind::Ident(_, ident, _) = arg.pat.kind {766 ident.name767 } else {768 debug!("{:#?}", &shadow_arg.pat);769 panic!("not an ident?");770 };771 let name: String = format!("b{}_{}", old_name, i);772 new_inputs.push(name.clone());773 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);774 shadow_arg.pat = Box::new(ast::Pat {775 id: ast::DUMMY_NODE_ID,776 kind: PatKind::Ident(BindingMode::NONE, ident, None),777 span: shadow_arg.pat.span,778 tokens: shadow_arg.pat.tokens.clone(),779 });780 d_inputs.push(shadow_arg.clone());781 }782 }783 DiffActivity::Const => {784 // Nothing to do here.785 }786 DiffActivity::None | DiffActivity::FakeActivitySize(_) => {787 panic!("Should not happen");788 }789 }790 if let PatKind::Ident(_, ident, _) = arg.pat.kind {791 idents.push(ident.clone());792 } else {793 panic!("not an ident?");794 }795 }796797 let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;798 if active_only_ret {799 assert!(x.mode.is_rev());800 }801802 // If we return a scalar in the primal and the scalar is active,803 // then add it as last arg to the inputs.804 if x.mode.is_rev() {805 match x.ret_activity {806 DiffActivity::Active | DiffActivity::ActiveOnly => {807 let ty = match d_decl.output {808 FnRetTy::Ty(ref ty) => ty.clone(),809 FnRetTy::Default(span) => {810 panic!("Did not expect Default ret ty: {:?}", span);811 }812 };813 let name = "dret".to_string();814 let ident = Ident::from_str_and_span(&name, ty.span);815 let shadow_arg = ast::Param {816 attrs: ThinVec::new(),817 ty: ty.clone(),818 pat: Box::new(ast::Pat {819 id: ast::DUMMY_NODE_ID,820 kind: PatKind::Ident(BindingMode::NONE, ident, None),821 span: ty.span,822 tokens: None,823 }),824 id: ast::DUMMY_NODE_ID,825 span: ty.span,826 is_placeholder: false,827 };828 d_inputs.push(shadow_arg);829 new_inputs.push(name);830 }831 _ => {}832 }833 }834 d_decl.inputs = d_inputs.into();835836 if x.mode.is_fwd() {837 let ty = match d_decl.output {838 FnRetTy::Ty(ref ty) => ty.clone(),839 
FnRetTy::Default(span) => {840 // We want to return std::hint::black_box(()).841 let kind = TyKind::Tup(ThinVec::new());842 let ty = Box::new(rustc_ast::Ty {843 kind,844 id: ast::DUMMY_NODE_ID,845 span,846 tokens: None,847 });848 d_decl.output = FnRetTy::Ty(ty.clone());849 assert!(matches!(x.ret_activity, DiffActivity::None));850 // this won't be used below, so any type would be fine.851 ty852 }853 };854855 if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {856 let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {857 // Dual can only be used for f32/f64 ret.858 // In that case we return now a tuple with two floats.859 TyKind::Tup(thin_vec![ty.clone(), ty.clone()])860 } else {861 // We have to return [T; width+1], +1 for the primal return.862 let anon_const = rustc_ast::AnonConst {863 id: ast::DUMMY_NODE_ID,864 value: ecx.expr_usize(span, 1 + x.width as usize),865 mgca_disambiguation: MgcaDisambiguation::Direct,866 };867 TyKind::Array(ty.clone(), anon_const)868 };869 let ty = Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });870 d_decl.output = FnRetTy::Ty(ty);871 }872 if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {873 // No need to change the return type,874 // we will just return the shadow in place of the primal return.875 // However, if we have a width > 1, then we don't return -> T, but -> [T; width]876 if x.width > 1 {877 let anon_const = rustc_ast::AnonConst {878 id: ast::DUMMY_NODE_ID,879 value: ecx.expr_usize(span, x.width as usize),880 mgca_disambiguation: MgcaDisambiguation::Direct,881 };882 let kind = TyKind::Array(ty.clone(), anon_const);883 let ty =884 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });885 d_decl.output = FnRetTy::Ty(ty);886 }887 }888 }889890 // If we use ActiveOnly, drop the original return value.891 d_decl.output =892 if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };893894 
trace!("act_ret: {:?}", act_ret);895896 // If we have an active input scalar, add it's gradient to the897 // return type. This might require changing the return type to a898 // tuple.899 if act_ret.len() > 0 {900 let ret_ty = match d_decl.output {901 FnRetTy::Ty(ref ty) => {902 if !active_only_ret {903 act_ret.insert(0, ty.clone());904 }905 let kind = TyKind::Tup(act_ret);906 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })907 }908 FnRetTy::Default(span) => {909 if act_ret.len() == 1 {910 act_ret[0].clone()911 } else {912 let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());913 Box::new(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })914 }915 }916 };917 d_decl.output = FnRetTy::Ty(ret_ty);918 }919920 let mut d_header = sig.header.clone();921 if unsafe_activities {922 d_header.safety = rustc_ast::Safety::Unsafe(span);923 }924 let d_sig = FnSig { header: d_header, decl: d_decl, span };925 trace!("Generated signature: {:?}", d_sig);926 d_sig927 }928}929930pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
Code quality findings 46
Warning: '.unwrap()' will panic on None/Err variants. Prefer using pattern matching (match, if let), combinators (map, and_then), or the '?' operator for robust error handling.
warning
correctness
unwrap-usage
let segments = &x.meta_item().unwrap().path.segments;
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
segments[0].ident
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
let width = if let [_, x, ..] = &meta_item[..]
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
span: meta_item[1].span(),
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
for x in &meta_item[first_activity..] {
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
&& let Some(width) = width(&meta_item_vec[1])
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
for t in meta_item_vec.clone()[start_position..].iter() {
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
first_ident(&meta_item_vec[0]),
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
ident: first_ident(&meta_item_vec[0]),
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
// if width =/= 1, then push [arg.ty; width] to act_ret
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
// We have to return [T; width+1], +1 for the primal return.
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning
correctness
unchecked-indexing
act_ret[0].clone()
Info: Wildcard imports (`use some::path::*;`) can obscure the origin of names and lead to conflicts. Prefer importing specific items explicitly.
info
maintainability
wildcard-import
use rustc_ast::tokenstream::*;
Info: Wildcard imports (`use some::path::*;`) can obscure the origin of names and lead to conflicts. Prefer importing specific items explicitly.
info
maintainability
wildcard-import
use rustc_ast::visit::AssocCtxt::*;
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
match lit.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
match &iitem.kind {
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
Ok(x) => activities.push(x),
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
let Some((vis, sig, primal, generics, is_impl)) = (match &item {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
Annotatable::Stmt(stmt) => match &stmt.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
match &assoc_item.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
let mode_symbol = match mode {
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
ts.push(TokenTree::Token(t, Spacing::Joint));
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
for t in meta_item_vec.clone()[start_position..].iter() {
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
ts.push(TokenTree::Token(t, Spacing::Joint));
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
ts.push(TokenTree::Token(comma, Spacing::Alone));
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
match (attr, item) {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
match stmt.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
match ty.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
let ty = match &param.ty.kind {
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info
correctness
match-wildcard
.map(|arg| match arg.pat.kind {
Maintainability Info: `todo!()` or `unimplemented!()` macros indicate incomplete code paths that will panic at runtime if reached. Ensure these are replaced with actual logic before production use.
info
correctness
todo-unimplemented
_ => todo!(),
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
let mut d_decl = sig.decl.clone();
Performance Info: Calling .to_string() (especially on &str) allocates a new String. If done repeatedly in loops, consider alternatives like working with &str or using crates like `itoa`/`ryu` for number-to-string conversion.
info
performance
to-string-in-loop
mode: x.mode.to_string(),
Performance Info: Calling .to_string() (especially on &str) allocates a new String. If done repeatedly in loops, consider alternatives like working with &str or using crates like `itoa`/`ryu` for number-to-string conversion.
info
performance
to-string-in-loop
act: activity.to_string(),
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
return sig.clone();
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
d_inputs.push(arg.clone());
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
d_inputs.push(arg.clone());
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
act_ret.push(arg.ty.clone());
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
act_ret.push(arg.ty.clone());
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
let mut shadow_arg = arg.clone();
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
let mut shadow_arg = arg.clone();
Performance Info: Frequent cloning, especially of Strings, Vecs, or other heap-allocated types inside loops, can be expensive. Consider using references/borrowing where possible.
info
performance
clone-in-loop
new_inputs.push(name.clone());
Performance Info: Calling .push() repeatedly inside a loop without prior capacity reservation can lead to multiple reallocations. Consider using `Vec::with_capacity(n)` or `vec.reserve(n)` if the approximate number of elements is known.
info
performance
push-without-reserve
new_inputs.push(name.clone());