/algorithms_for_tapenade/fortran_code/qr_b.f
http://github.com/b45ch1/hpsc_hanoi_2009_walter · FORTRAN Legacy · 135 lines · 95 code · 0 blank · 40 comment · 0 complexity · 24f491681f37aa519c8f2dacbe6ea13d MD5 · raw file
C Generated by TAPENADE (INRIA, Tropics team)
C Tapenade 3.4 (r3375) - 10 Feb 2010 15:08
C
C Differentiation of qr in reverse (adjoint) mode:
C gradient of useful results: r qt
C with respect to varying inputs: qt a
C
C QR_B first replays the primal Givens-rotation QR sweep (saving
C every overwritten value on the Tapenade stack) and then runs the
C adjoint sweep in reverse order, popping those values back.
C
C HAND-EDITED after generation: when the pivot pair at = bt = 0
C (or at*at+bt*bt underflows) the rotation is now the identity
C (c = 1, s = 0, as LAPACK DLARTG does for g = 0) instead of
C dividing by rt = 0, which used to turn every output into
C NaN/Inf, e.g. for a zero column.  Regenerating with Tapenade
C drops this guard unless the primal routine QR is fixed too.
C
C Storage: every matrix is a flat array of length na*na and index
C i*na+j+1 addresses entry (i,j), 0-based.  The sweep rotates the
C contiguous segments i*na+1 .. i*na+na, i.e. rows if the arrays
C are read row-major (presumably the caller's layout - confirm).
C
C   a   in     matrix A, only read.
C   ab  out    adjoint of A; OVERWRITTEN, not accumulated into.
C   qt  inout  workspace: reset to the identity at the start; the
C              reverse sweep restores it, so it holds the identity
C              (not Q^T) on exit.
C   qtb inout  adjoint seed for QT on entry; zero on exit.
C   r   inout  workspace: set to A at the start and restored to A
C              on exit (it does not hold the factor R on return).
C   rb  inout  adjoint seed for R on entry; zero on exit.
C   na  in     matrix dimension.
C
C Needs Tapenade's PUSHREAL8/POPREAL8/PUSHINTEGER4/POPINTEGER4.
      SUBROUTINE QR_B(a, ab, qt, qtb, r, rb, na)
      IMPLICIT NONE
      INTEGER*4 na
      REAL*8 a(na*na), qt(na*na), r(na*na)
      REAL*8 ab(na*na), qtb(na*na), rb(na*na)
      REAL*8 tmp, at, bt, rt, c, s, rnk, qtnk
      REAL*8 atb, btb, rtb, cb, sb, rnkb, qtnkb
      INTEGER*4 n, m, k
      REAL*8 tmp0
      REAL*8 tmp1
C     4-byte like n, m, k: it is taped with PUSH/POPINTEGER4.
      INTEGER*4 ad_from
      REAL*8 tmp0b
      REAL*8 tempb
      INTEGER ii1
      INTRINSIC SQRT
      REAL*8 tmp1b
C
C ---- Forward sweep: QT := I, R := A ------------------------------
      DO n=0,na-1
        DO m=0,na-1
          IF (n .EQ. m) THEN
            tmp = 1.0
          ELSE
            tmp = 0.0
          END IF
          qt(n*na+m+1) = tmp
        ENDDO
      ENDDO
      DO n=1,na*na
        r(n) = a(n)
      ENDDO
C
C ---- Forward sweep: zero entry (m,n) below each pivot (n,n) by
C      rotating segments n and m of R and QT, taping old values.
      DO n=0,na-1
        ad_from = n + 1
        DO m=ad_from,na-1
          CALL PUSHREAL8(at)
          at = r(n*na+n+1)
          CALL PUSHREAL8(bt)
          bt = r(m*na+n+1)
          CALL PUSHREAL8(rt)
          rt = SQRT(at*at + bt*bt)
          CALL PUSHREAL8(c)
          CALL PUSHREAL8(s)
          IF (rt .EQ. 0.0D0) THEN
C           Nothing to eliminate: identity rotation, avoids 0/0.
            c = 1.0D0
            s = 0.0D0
          ELSE
            c = at/rt
            s = bt/rt
          END IF
          DO k=1,na
C           UPDATE R
            CALL PUSHREAL8(rnk)
            rnk = r(n*na+k)
            tmp0 = c*rnk + s*r(m*na+k)
            CALL PUSHREAL8(r(n*na+k))
            r(n*na+k) = tmp0
            CALL PUSHREAL8(r(m*na+k))
            r(m*na+k) = -(s*rnk) + c*r(m*na+k)
C           UPDATE Q
            CALL PUSHREAL8(qtnk)
            qtnk = qt(n*na+k)
            tmp1 = c*qtnk + s*qt(m*na+k)
            CALL PUSHREAL8(qt(n*na+k))
            qt(n*na+k) = tmp1
            CALL PUSHREAL8(qt(m*na+k))
            qt(m*na+k) = -(s*qtnk) + c*qt(m*na+k)
          ENDDO
        ENDDO
        CALL PUSHINTEGER4(ad_from)
      ENDDO
C
C ---- Reverse sweep: each POP undoes the matching PUSH above
C      (LIFO), restoring the value the variable had before the
C      forward statement overwrote it.
      DO n=na-1,0,-1
        CALL POPINTEGER4(ad_from)
        DO m=na-1,ad_from,-1
          sb = 0.0
          cb = 0.0
          DO k=na,1,-1
C           Adjoint of the QT segment update.
            CALL POPREAL8(qt(m*na+k))
            cb = cb + qt(m*na+k)*qtb(m*na+k)
            sb = sb - qtnk*qtb(m*na+k)
            qtnkb = -(s*qtb(m*na+k))
            qtb(m*na+k) = c*qtb(m*na+k)
            CALL POPREAL8(qt(n*na+k))
            tmp1b = qtb(n*na+k)
            qtb(n*na+k) = 0.0
            cb = cb + qtnk*tmp1b
            qtnkb = qtnkb + c*tmp1b
C           The -rnk*rb(m*na+k) term is d/ds of the R segment-m
C           update; it is taken here, before rb(m*na+k) is rescaled.
            sb = sb + qt(m*na+k)*tmp1b - rnk*rb(m*na+k)
            qtb(m*na+k) = qtb(m*na+k) + s*tmp1b
            CALL POPREAL8(qtnk)
            qtb(n*na+k) = qtb(n*na+k) + qtnkb
C           Adjoint of the R segment update.
            CALL POPREAL8(r(m*na+k))
            cb = cb + r(m*na+k)*rb(m*na+k)
            rnkb = -(s*rb(m*na+k))
            rb(m*na+k) = c*rb(m*na+k)
            CALL POPREAL8(r(n*na+k))
            tmp0b = rb(n*na+k)
            rb(n*na+k) = 0.0
            cb = cb + rnk*tmp0b
            rnkb = rnkb + c*tmp0b
            sb = sb + r(m*na+k)*tmp0b
            rb(m*na+k) = rb(m*na+k) + s*tmp0b
            CALL POPREAL8(rnk)
            rb(n*na+k) = rb(n*na+k) + rnkb
          ENDDO
          CALL POPREAL8(s)
C         rt still holds this rotation's forward value here (it is
C         popped below), so this branch matches the forward one.
          IF (rt .EQ. 0.0D0) THEN
C           c and s were constants here, so no adjoint reaches at/bt.
            atb = 0.0D0
            btb = 0.0D0
          ELSE
C           Adjoint of c = at/rt, s = bt/rt, rt = SQRT(at*at+bt*bt);
C           rt is nonzero in this branch.
            rtb = -(at*cb/rt**2) - bt*sb/rt**2
            tempb = rtb/(2.0*rt)
            btb = 2*bt*tempb + sb/rt
            atb = 2*at*tempb + cb/rt
          END IF
          CALL POPREAL8(c)
          CALL POPREAL8(rt)
          CALL POPREAL8(bt)
          rb(m*na+n+1) = rb(m*na+n+1) + btb
          CALL POPREAL8(at)
          rb(n*na+n+1) = rb(n*na+n+1) + atb
        ENDDO
      ENDDO
C
C ---- Adjoint of R := A: move rb into ab (ab is overwritten).
      DO ii1=1,na*na
        ab(ii1) = 0.0
      ENDDO
      DO n=na*na,1,-1
        ab(n) = ab(n) + rb(n)
        rb(n) = 0.0
      ENDDO
C
C ---- QT := I does not depend on any input, so its adjoint is
C      simply cleared.
      DO n=na-1,0,-1
        DO m=na-1,0,-1
          qtb(n*na+m+1) = 0.0
        ENDDO
      ENDDO
      END