compiler/rustc_ast/src/expand/autodiff_attrs.rs (Rust, 228 lines)
```rust
//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create a `RustcAutodiff` which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! that we generate (with a name given by the user as the first autodiff arg).

use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use rustc_span::{Symbol, sym};

use crate::expand::{Decodable, Encodable, StableHash};
use crate::{Ty, TyKind};

/// Forward and Reverse Mode are well-known names for automatic differentiation implementations.
/// Enzyme supports both, but with different semantics; see `DiffActivity`. The First variants
/// are a hack to support higher-order derivatives. We need to compute first-order derivatives
/// before we compute second-order derivatives, otherwise we would differentiate our placeholder
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
/// as is already done in the C++ and Julia frontends of Enzyme.
///
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)]
pub enum DiffMode {
    /// No autodiff is applied (used during error handling).
    Error,
    /// The primal function which we will differentiate.
    Source,
    /// The target function, to be created using forward mode AD.
    Forward,
    /// The target function, to be created using reverse mode AD.
    Reverse,
}

impl DiffMode {
    pub fn all_modes() -> &'static [Symbol] {
        &[sym::Source, sym::Forward, sym::Reverse]
    }
}

/// Dual and Duplicated (and their Only variants) are lowered to the same Enzyme activity.
/// However, in forward mode we overwrite the previous shadow value, while in reverse mode
/// we add to it. To avoid surprising users, we picked different names.
/// "Dual numbers" is also a well-known name for forward-mode AD types.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)]
pub enum DiffActivity {
    /// Implicit or explicit `()` return type, so a special case of Const.
    None,
    /// Don't compute derivatives with respect to this input/output.
    Const,
    /// Reverse mode: compute derivatives for this scalar input/output.
    Active,
    /// Reverse mode: compute derivatives for this scalar output, but don't compute
    /// the original return value.
    ActiveOnly,
    /// Forward mode: compute derivatives for this input/output and *overwrite* the shadow argument
    /// with them.
    Dual,
    /// Forward mode: compute derivatives for this input/output and *overwrite* the shadow argument
    /// with them. It expects the shadow argument to be `width` times larger than the original
    /// input/output.
    Dualv,
    /// Forward mode: compute derivatives for this input/output and *overwrite* the shadow argument
    /// with them. Drop the code which updates the original input/output for maximum performance.
    DualOnly,
    /// Forward mode: compute derivatives for this input/output and *overwrite* the shadow argument
    /// with them. Drop the code which updates the original input/output for maximum performance.
    /// It expects the shadow argument to be `width` times larger than the original input/output.
    DualvOnly,
    /// Reverse mode: compute derivatives for this `&T` or `*T` input and *add* them to the shadow
    /// argument.
    Duplicated,
    /// Reverse mode: compute derivatives for this `&T` or `*T` input and *add* them to the shadow
    /// argument. Drop the code which updates the original input for maximum performance.
    DuplicatedOnly,
    /// All integers must be Const, but these are used to mark the integer which represents the
    /// length of a slice/vec. This is used for safety checks on slices.
    /// The integer (if given) specifies the size of the slice element in bytes.
    FakeActivitySize(Option<u32>),
}

impl DiffActivity {
    pub fn is_dual_or_const(&self) -> bool {
        use DiffActivity::*;
        matches!(self, Dual | DualOnly | Dualv | DualvOnly | Const)
    }

    pub fn all_activities() -> &'static [Symbol] {
        &[
            sym::None,
            sym::Active,
            sym::ActiveOnly,
            sym::Const,
            sym::Dual,
            sym::Dualv,
            sym::DualOnly,
            sym::DualvOnly,
            sym::Duplicated,
            sym::DuplicatedOnly,
        ]
    }
}

impl DiffMode {
    pub fn is_rev(&self) -> bool {
        matches!(self, DiffMode::Reverse)
    }
    pub fn is_fwd(&self) -> bool {
        matches!(self, DiffMode::Forward)
    }
}

impl Display for DiffMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            DiffMode::Error => write!(f, "Error"),
            DiffMode::Source => write!(f, "Source"),
            DiffMode::Forward => write!(f, "Forward"),
            DiffMode::Reverse => write!(f, "Reverse"),
        }
    }
}

/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
/// That usually means we have a `&mut T` or `*mut T` output and compute derivatives wrt. that arg,
/// but this is too complex to verify here. Also, it's just a logic error if users get this wrong.
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
    if activity == DiffActivity::None {
        // Only valid if primal returns (), but we can't check that here.
        return true;
    }
    match mode {
        DiffMode::Error => false,
        DiffMode::Source => false,
        DiffMode::Forward => activity.is_dual_or_const(),
        DiffMode::Reverse => {
            activity == DiffActivity::Const
                || activity == DiffActivity::Active
                || activity == DiffActivity::ActiveOnly
        }
    }
}

/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
/// for the given argument, but we generally can't know the size of such a type.
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
/// that is an indirect type, which doesn't match the primal scalar type. Because of type
/// aliases, we can't prevent users from marking scalars as Duplicated here.
pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {
    use DiffActivity::*;
    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
    // Dual variants also support all types.
    if activity.is_dual_or_const() {
        return true;
    }
    // FIXME(ZuseZ4) We should make this more robust to also
    // handle type aliases. Once that is done, we can be more restrictive here.
    if matches!(activity, Active | ActiveOnly) {
        return true;
    }
    matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
        && matches!(activity, Duplicated | DuplicatedOnly)
}

/// Checks whether `activity` may be used for an input argument under the given `mode`.
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
    use DiffActivity::*;
    match mode {
        DiffMode::Error => false,
        DiffMode::Source => false,
        DiffMode::Forward => activity.is_dual_or_const(),
        DiffMode::Reverse => {
            matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
        }
    }
}

impl Display for DiffActivity {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            DiffActivity::None => write!(f, "None"),
            DiffActivity::Const => write!(f, "Const"),
            DiffActivity::Active => write!(f, "Active"),
            DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
            DiffActivity::Dual => write!(f, "Dual"),
            DiffActivity::Dualv => write!(f, "Dualv"),
            DiffActivity::DualOnly => write!(f, "DualOnly"),
            DiffActivity::DualvOnly => write!(f, "DualvOnly"),
            DiffActivity::Duplicated => write!(f, "Duplicated"),
            DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
            DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
        }
    }
}

impl FromStr for DiffMode {
    type Err = ();

    fn from_str(s: &str) -> Result<DiffMode, ()> {
        match s {
            "Error" => Ok(DiffMode::Error),
            "Source" => Ok(DiffMode::Source),
            "Forward" => Ok(DiffMode::Forward),
            "Reverse" => Ok(DiffMode::Reverse),
            _ => Err(()),
        }
    }
}

impl FromStr for DiffActivity {
    type Err = ();

    fn from_str(s: &str) -> Result<DiffActivity, ()> {
        match s {
            "None" => Ok(DiffActivity::None),
            "Active" => Ok(DiffActivity::Active),
            "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
            "Const" => Ok(DiffActivity::Const),
            "Dual" => Ok(DiffActivity::Dual),
            "Dualv" => Ok(DiffActivity::Dualv),
            "DualOnly" => Ok(DiffActivity::DualOnly),
            "DualvOnly" => Ok(DiffActivity::DualvOnly),
            "Duplicated" => Ok(DiffActivity::Duplicated),
            "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
            _ => Err(()),
        }
    }
}
```
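The `DiffActivity` docs distinguish two kinds of shadow update. Forward mode (`Dual`) *overwrites* the shadow, and reverse mode (`Duplicated`) *adds* to it. The sketch below shows that difference using hand-written derivatives of `f(x, y) = x * y + x`. It is not code generated by Enzyme, and `f`, `df_forward`, and `df_reverse` are hypothetical names chosen for this illustration.

```rust
// Primal function: f(x, y) = x * y + x.
fn f(x: f64, y: f64) -> f64 {
    x * y + x
}

// Hypothetical forward-mode derivative in the spirit of `Dual`:
// given tangents (dx, dy), compute the primal and *overwrite* `d_out`
// with the directional derivative df = (y + 1) * dx + x * dy.
fn df_forward(x: f64, dx: f64, y: f64, dy: f64, out: &mut f64, d_out: &mut f64) {
    *out = f(x, y);
    *d_out = (y + 1.0) * dx + x * dy; // previous shadow value is discarded
}

// Hypothetical reverse-mode derivative in the spirit of `Duplicated`:
// given a seed for the output, *add* the adjoints into the input shadows.
fn df_reverse(x: f64, dx: &mut f64, y: f64, dy: &mut f64, seed: f64) -> f64 {
    *dx += (y + 1.0) * seed; // accumulates into the existing shadow
    *dy += x * seed;
    f(x, y)
}
```

Calling `df_reverse` twice with the same shadows doubles the gradient. This is why reverse-mode shadows usually need to be zeroed before each call, whereas a stale forward-mode shadow is simply replaced.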

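The `Dualv` and `DualvOnly` variants expect a shadow argument `width` times larger than the original. The following is a minimal sketch of vector forward mode for `f(x) = sum(x_i^2)`. It assumes the `width` tangent vectors are stored back to back in one buffer; this file does not specify the memory layout Enzyme actually uses. The names `f` and `df_forward_vec` are hypothetical.

```rust
// Primal function: sum of squares.
fn f(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

// Hypothetical vector forward-mode derivative of `f`.
// Assumption: `dx` holds `width` tangent vectors back to back,
// so `dx.len() == width * x.len()`. `d_out` receives one directional
// derivative per tangent and is overwritten (forward-mode semantics).
fn df_forward_vec(x: &[f64], dx: &[f64], width: usize, d_out: &mut [f64]) -> f64 {
    let n = x.len();
    assert_eq!(dx.len(), width * n, "shadow must be `width` times larger");
    assert_eq!(d_out.len(), width);
    for (k, out) in d_out.iter_mut().enumerate() {
        let tangent = &dx[k * n..(k + 1) * n];
        // Directional derivative of sum(x_i^2) along v is sum(2 * x_i * v_i).
        *out = x.iter().zip(tangent).map(|(xi, vi)| 2.0 * xi * vi).sum();
    }
    f(x)
}
```

Seeding the tangents with unit vectors yields the full gradient in a single call. This is the main reason to pay for the larger shadow buffer.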
Code quality findings 8

Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
pub fn all_modes() -> &'static [Symbol] {
Warning: Direct indexing (e.g., `vec[i]`, `slice[i]`) panics on out-of-bounds access. Prefer using `.get(index)` or `.get_mut(index)` which return Option<&T>/Option<&mut T>.
warning correctness unchecked-indexing
pub fn all_activities() -> &'static [Symbol] {
Info: Wildcard imports (`use some::path::*;`) can obscure the origin of names and lead to conflicts. Prefer importing specific items explicitly.
info maintainability wildcard-import
use DiffActivity::*;
Info: Use of raw pointers (*const T, *mut T) typically requires 'unsafe' blocks for dereferencing. Ensure usage is justified (FFI, low-level optimizations) and memory safety is manually upheld.
info safety raw-pointer
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
Info: Wildcard imports (`use some::path::*;`) can obscure the origin of names and lead to conflicts. Prefer importing specific items explicitly.
info maintainability wildcard-import
use DiffActivity::*;
Info: Wildcard imports (`use some::path::*;`) can obscure the origin of names and lead to conflicts. Prefer importing specific items explicitly.
info maintainability wildcard-import
use DiffActivity::*;
Info: Ensure 'match' statements are exhaustive. If matching on enums, consider adding a wildcard arm `_ => {}` only if necessary and intentional, as it suppresses warnings about unhandled variants.
info correctness match-wildcard
match s {
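The last finding refers to the `FromStr` impls. They match on a `&str`, which has no finite set of values, so the `_ => Err(())` arm is required there. The `Display` impls, by contrast, match every enum variant without a wildcard. One way to keep the two directions in sync is a round-trip check. This sketch uses a standalone copy of `DiffMode` without the rustc-specific derives. The `ALL` constant and the `round_trips` function are hypothetical additions that are not part of the file.

```rust
use std::fmt;
use std::str::FromStr;

// Standalone copy of `DiffMode`, without the rustc-specific derives.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DiffMode {
    Error,
    Source,
    Forward,
    Reverse,
}

impl DiffMode {
    // Hypothetical list of every variant; must be updated by hand.
    const ALL: [DiffMode; 4] =
        [DiffMode::Error, DiffMode::Source, DiffMode::Forward, DiffMode::Reverse];
}

impl fmt::Display for DiffMode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive match without a wildcard: a new variant fails to compile here.
        let name = match self {
            DiffMode::Error => "Error",
            DiffMode::Source => "Source",
            DiffMode::Forward => "Forward",
            DiffMode::Reverse => "Reverse",
        };
        f.write_str(name)
    }
}

impl FromStr for DiffMode {
    type Err = ();

    fn from_str(s: &str) -> Result<DiffMode, ()> {
        match s {
            "Error" => Ok(DiffMode::Error),
            "Source" => Ok(DiffMode::Source),
            "Forward" => Ok(DiffMode::Forward),
            "Reverse" => Ok(DiffMode::Reverse),
            // Strings are unbounded, so a catch-all arm is required when matching on `&str`.
            _ => Err(()),
        }
    }
}

// Hypothetical check: every listed variant survives `to_string` followed by `parse`.
fn round_trips() -> bool {
    DiffMode::ALL
        .iter()
        .all(|mode| mode.to_string().parse::<DiffMode>() == Ok(*mode))
}
```

The compiler flags a new variant only in the exhaustive `Display` match. `ALL` still has to be extended manually before the round-trip check covers it.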

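The per-mode rules in `valid_input_activity` can be exercised outside the compiler using a simplified mirror of the types. This sketch makes several assumptions. It uses a pared-down `Mode`/`Activity` pair that omits `None`, the vector `Dualv` variants, and `FakeActivitySize`. It also adds a hypothetical `check_inputs` helper that parses activity names, as they might be written in an `#[autodiff(...)]` attribute, and then applies the same rules as the source.

```rust
use std::str::FromStr;

// Simplified mirrors of `DiffMode` / `DiffActivity`; not the rustc types.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Mode {
    Forward,
    Reverse,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Activity {
    Const,
    Active,
    ActiveOnly,
    Dual,
    DualOnly,
    Duplicated,
    DuplicatedOnly,
}

impl FromStr for Activity {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "Const" => Ok(Activity::Const),
            "Active" => Ok(Activity::Active),
            "ActiveOnly" => Ok(Activity::ActiveOnly),
            "Dual" => Ok(Activity::Dual),
            "DualOnly" => Ok(Activity::DualOnly),
            "Duplicated" => Ok(Activity::Duplicated),
            "DuplicatedOnly" => Ok(Activity::DuplicatedOnly),
            other => Err(format!("unknown activity `{other}`")),
        }
    }
}

// Same per-mode rules as `valid_input_activity`, restricted to the mirrored variants.
fn valid_input_activity(mode: Mode, activity: Activity) -> bool {
    use Activity::*;
    match mode {
        Mode::Forward => matches!(activity, Dual | DualOnly | Const),
        Mode::Reverse => {
            matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
        }
    }
}

// Hypothetical helper: parse each name, then reject activities invalid for `mode`.
fn check_inputs(mode: Mode, names: &[&str]) -> Result<Vec<Activity>, String> {
    names
        .iter()
        .map(|name| -> Result<Activity, String> {
            let activity: Activity = name.parse()?;
            if valid_input_activity(mode, activity) {
                Ok(activity)
            } else {
                Err(format!("{activity:?} is not valid in {mode:?} mode"))
            }
        })
        .collect()
}
```

Collecting into `Result<Vec<_>, _>` stops at the first unknown name or invalid activity. That failure is a natural place to attach a diagnostic.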