1//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,2//! we create an `RustcAutodiff` which contains the source and target function names. The source3//! is the function to which the autodiff attribute is applied, and the target is the function4//! getting generated by us (with a name given by the user as the first autodiff arg).56use std::fmt::{self, Display, Formatter};7use std::str::FromStr;89use rustc_span::{Symbol, sym};1011use crate::expand::{Decodable, Encodable, StableHash};12use crate::{Ty, TyKind};1314/// Forward and Reverse Mode are well known names for automatic differentiation implementations.15/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants16/// are a hack to support higher order derivatives. We need to compute first order derivatives17/// before we compute second order derivatives, otherwise we would differentiate our placeholder18/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,19/// as it's already done in the C++ and Julia frontend of Enzyme.20///21/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and22/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.23#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)]24pub enum DiffMode {25 /// No autodiff is applied (used during error handling).26 Error,27 /// The primal function which we will differentiate.28 Source,29 /// The target function, to be created using forward mode AD.30 Forward,31 /// The target function, to be created using reverse mode AD.32 Reverse,33}3435impl DiffMode {36 pub fn all_modes() -> &'static [Symbol] {37 &[sym::Source, sym::Forward, sym::Reverse]38 }39}4041/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.42/// However, under forward mode we overwrite the previous shadow value, while for reverse mode43/// we add to the previous shadow value. To not surprise users, we picked different names.44/// Dual numbers is also a quite well known name for forward mode AD types.45#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, StableHash)]46pub enum DiffActivity {47 /// Implicit or Explicit () return type, so a special case of Const.48 None,49 /// Don't compute derivatives with respect to this input/output.50 Const,51 /// Reverse Mode, Compute derivatives for this scalar input/output.52 Active,53 /// Reverse Mode, Compute derivatives for this scalar output, but don't compute54 /// the original return value.55 ActiveOnly,56 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument57 /// with it.58 Dual,59 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument60 /// with it. It expects the shadow argument to be `width` times larger than the original61 /// input/output.62 Dualv,63 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument64 /// with it. Drop the code which updates the original input/output for maximum performance.65 DualOnly,66 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument67 /// with it. Drop the code which updates the original input/output for maximum performance.68 /// It expects the shadow argument to be `width` times larger than the original input/output.69 DualvOnly,70 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.71 Duplicated,72 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.73 /// Drop the code which updates the original input for maximum performance.74 DuplicatedOnly,75 /// All Integers must be Const, but these are used to mark the integer which represents the76 /// length of a slice/vec. This is used for safety checks on slices.77 /// The integer (if given) specifies the size of the slice element in bytes.78 FakeActivitySize(Option<u32>),79}8081impl DiffActivity {82 pub fn is_dual_or_const(&self) -> bool {83 use DiffActivity::*;84 matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)85 }8687 pub fn all_activities() -> &'static [Symbol] {88 &[89 sym::None,90 sym::Active,91 sym::ActiveOnly,92 sym::Const,93 sym::Dual,94 sym::Dualv,95 sym::DualOnly,96 sym::DualvOnly,97 sym::Duplicated,98 sym::DuplicatedOnly,99 ]100 }101}102103impl DiffMode {104 pub fn is_rev(&self) -> bool {105 matches!(self, DiffMode::Reverse)106 }107 pub fn is_fwd(&self) -> bool {108 matches!(self, DiffMode::Forward)109 }110}111112impl Display for DiffMode {113 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {114 match self {115 DiffMode::Error => write!(f, "Error"),116 DiffMode::Source => write!(f, "Source"),117 DiffMode::Forward => write!(f, "Forward"),118 DiffMode::Reverse => write!(f, "Reverse"),119 }120 }121}122123/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).124/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).125/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.126/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,127/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.128pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {129 if activity == DiffActivity::None {130 // Only valid if primal returns (), but we can't check that here.131 return true;132 }133 match mode {134 DiffMode::Error => false,135 DiffMode::Source => false,136 DiffMode::Forward => activity.is_dual_or_const(),137 DiffMode::Reverse => {138 activity == DiffActivity::Const139 || activity == DiffActivity::Active140 || activity == DiffActivity::ActiveOnly141 }142 }143}144145/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value146/// for the given argument, but we generally can't know the size of such a type.147/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,148/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value149/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent150/// users here from marking scalars as Duplicated, due to type aliases.151pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {152 use DiffActivity::*;153 // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.154 // Dual variants also support all types.155 if activity.is_dual_or_const() {156 return true;157 }158 // FIXME(ZuseZ4) We should make this more robust to also159 // handle type aliases. Once that is done, we can be more restrictive here.160 if matches!(activity, Active | ActiveOnly) {161 return true;162 }163 matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))164 && matches!(activity, Duplicated | DuplicatedOnly)165}166pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {167 use DiffActivity::*;168 return match mode {169 DiffMode::Error => false,170 DiffMode::Source => false,171 DiffMode::Forward => activity.is_dual_or_const(),172 DiffMode::Reverse => {173 matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)174 }175 };176}177178impl Display for DiffActivity {179 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {180 match self {181 DiffActivity::None => write!(f, "None"),182 DiffActivity::Const => write!(f, "Const"),183 DiffActivity::Active => write!(f, "Active"),184 DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),185 DiffActivity::Dual => write!(f, "Dual"),186 DiffActivity::Dualv => write!(f, "Dualv"),187 DiffActivity::DualOnly => write!(f, "DualOnly"),188 DiffActivity::DualvOnly => write!(f, "DualvOnly"),189 DiffActivity::Duplicated => write!(f, "Duplicated"),190 DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),191 DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),192 }193 }194}195196impl FromStr for DiffMode {197 type Err = ();198199 fn from_str(s: &str) -> Result<DiffMode, ()> {200 match s {201 "Error" => Ok(DiffMode::Error),202 "Source" => Ok(DiffMode::Source),203 "Forward" => Ok(DiffMode::Forward),204 "Reverse" => Ok(DiffMode::Reverse),205 _ => Err(()),206 }207 }208}209impl FromStr for DiffActivity {210 type Err = ();211212 fn from_str(s: &str) -> Result<DiffActivity, ()> {213 match s {214 "None" => Ok(DiffActivity::None),215 "Active" => Ok(DiffActivity::Active),216 "ActiveOnly" => Ok(DiffActivity::ActiveOnly),217 "Const" => Ok(DiffActivity::Const),218 "Dual" => Ok(DiffActivity::Dual),219 "Dualv" => Ok(DiffActivity::Dualv),220 "DualOnly" => Ok(DiffActivity::DualOnly),221 "DualvOnly" => Ok(DiffActivity::DualvOnly),222 "Duplicated" => Ok(DiffActivity::Duplicated),223 "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),224 _ => Err(()),225 }226 }227}