/oauth/oauth.go

https://code.google.com/p/goauth2/ · Go · 398 lines · 258 code · 37 blank · 103 comment · 79 complexity · 242822d19c909c0864ea5aa79fe0aad1 MD5 · raw file

  1. // Copyright 2011 The goauth2 Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package oauth provides support for making
// OAuth2-authenticated HTTP requests.
  6. //
  7. // Example usage:
  8. //
  9. // // Specify your configuration. (typically as a global variable)
  10. // var config = &oauth.Config{
  11. // ClientId: YOUR_CLIENT_ID,
  12. // ClientSecret: YOUR_CLIENT_SECRET,
  13. // Scope: "https://www.googleapis.com/auth/buzz",
  14. // AuthURL: "https://accounts.google.com/o/oauth2/auth",
  15. // TokenURL: "https://accounts.google.com/o/oauth2/token",
  16. // RedirectURL: "http://you.example.org/handler",
  17. // }
  18. //
  19. // // A landing page redirects to the OAuth provider to get the auth code.
  20. // func landing(w http.ResponseWriter, r *http.Request) {
  21. // http.Redirect(w, r, config.AuthCodeURL("foo"), http.StatusFound)
  22. // }
  23. //
  24. // // The user will be redirected back to this handler, that takes the
  25. // // "code" query parameter and Exchanges it for an access token.
  26. // func handler(w http.ResponseWriter, r *http.Request) {
  27. // t := &oauth.Transport{Config: config}
// if _, err := t.Exchange(r.FormValue("code")); err != nil {
// http.Error(w, err.Error(), http.StatusBadRequest)
// return
// }
  29. // // The Transport now has a valid Token. Create an *http.Client
  30. // // with which we can make authenticated API requests.
  31. // c := t.Client()
  32. // c.Post(...)
  33. // // ...
// // The state passed to AuthCodeURL is returned unchanged: r.FormValue("state") == "foo"
  35. // }
  36. //
  37. package oauth
  38. import (
  39. "encoding/json"
  40. "io/ioutil"
  41. "mime"
  42. "net/http"
  43. "net/url"
  44. "os"
  45. "strings"
  46. "time"
  47. )
  48. type OAuthError struct {
  49. prefix string
  50. msg string
  51. }
  52. func (oe OAuthError) Error() string {
  53. return "OAuthError: " + oe.prefix + ": " + oe.msg
  54. }
  55. // Cache specifies the methods that implement a Token cache.
  56. type Cache interface {
  57. Token() (*Token, error)
  58. PutToken(*Token) error
  59. }
  60. // CacheFile implements Cache. Its value is the name of the file in which
  61. // the Token is stored in JSON format.
  62. type CacheFile string
  63. func (f CacheFile) Token() (*Token, error) {
  64. file, err := os.Open(string(f))
  65. if err != nil {
  66. return nil, OAuthError{"CacheFile.Token", err.Error()}
  67. }
  68. defer file.Close()
  69. tok := &Token{}
  70. if err := json.NewDecoder(file).Decode(tok); err != nil {
  71. return nil, OAuthError{"CacheFile.Token", err.Error()}
  72. }
  73. return tok, nil
  74. }
  75. func (f CacheFile) PutToken(tok *Token) error {
  76. file, err := os.OpenFile(string(f), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
  77. if err != nil {
  78. return OAuthError{"CacheFile.PutToken", err.Error()}
  79. }
  80. if err := json.NewEncoder(file).Encode(tok); err != nil {
  81. file.Close()
  82. return OAuthError{"CacheFile.PutToken", err.Error()}
  83. }
  84. if err := file.Close(); err != nil {
  85. return OAuthError{"CacheFile.PutToken", err.Error()}
  86. }
  87. return nil
  88. }
  89. // Config is the configuration of an OAuth consumer.
  90. type Config struct {
  91. // ClientId is the OAuth client identifier used when communicating with
  92. // the configured OAuth provider.
  93. ClientId string
  94. // ClientSecret is the OAuth client secret used when communicating with
  95. // the configured OAuth provider.
  96. ClientSecret string
  97. // Scope identifies the level of access being requested. Multiple scope
  98. // values should be provided as a space-delimited string.
  99. Scope string
  100. // AuthURL is the URL the user will be directed to in order to grant
  101. // access.
  102. AuthURL string
  103. // TokenURL is the URL used to retrieve OAuth tokens.
  104. TokenURL string
  105. // RedirectURL is the URL to which the user will be returned after
  106. // granting (or denying) access.
  107. RedirectURL string
  108. // TokenCache allows tokens to be cached for subsequent requests.
  109. TokenCache Cache
  110. AccessType string // Optional, "online" (default) or "offline", no refresh token if "online"
  111. // ApprovalPrompt indicates whether the user should be
  112. // re-prompted for consent. If set to "auto" (default) the
  113. // user will be prompted only if they haven't previously
  114. // granted consent and the code can only be exchanged for an
  115. // access token.
  116. // If set to "force" the user will always be prompted, and the
  117. // code can be exchanged for a refresh token.
  118. ApprovalPrompt string
  119. }
  120. // Token contains an end-user's tokens.
  121. // This is the data you must store to persist authentication.
  122. type Token struct {
  123. AccessToken string
  124. RefreshToken string
  125. Expiry time.Time // If zero the token has no (known) expiry time.
  126. Extra map[string]string // May be nil.
  127. }
  128. func (t *Token) Expired() bool {
  129. if t.Expiry.IsZero() {
  130. return false
  131. }
  132. return t.Expiry.Before(time.Now())
  133. }
  134. // Transport implements http.RoundTripper. When configured with a valid
  135. // Config and Token it can be used to make authenticated HTTP requests.
  136. //
  137. // t := &oauth.Transport{config}
  138. // t.Exchange(code)
  139. // // t now contains a valid Token
  140. // r, _, err := t.Client().Get("http://example.org/url/requiring/auth")
  141. //
  142. // It will automatically refresh the Token if it can,
  143. // updating the supplied Token in place.
  144. type Transport struct {
  145. *Config
  146. *Token
  147. // Transport is the HTTP transport to use when making requests.
  148. // It will default to http.DefaultTransport if nil.
  149. // (It should never be an oauth.Transport.)
  150. Transport http.RoundTripper
  151. }
  152. // Client returns an *http.Client that makes OAuth-authenticated requests.
  153. func (t *Transport) Client() *http.Client {
  154. return &http.Client{Transport: t}
  155. }
  156. func (t *Transport) transport() http.RoundTripper {
  157. if t.Transport != nil {
  158. return t.Transport
  159. }
  160. return http.DefaultTransport
  161. }
  162. // AuthCodeURL returns a URL that the end-user should be redirected to,
  163. // so that they may obtain an authorization code.
  164. func (c *Config) AuthCodeURL(state string) string {
  165. url_, err := url.Parse(c.AuthURL)
  166. if err != nil {
  167. panic("AuthURL malformed: " + err.Error())
  168. }
  169. q := url.Values{
  170. "response_type": {"code"},
  171. "client_id": {c.ClientId},
  172. "redirect_uri": {c.RedirectURL},
  173. "scope": {c.Scope},
  174. "state": {state},
  175. "access_type": {c.AccessType},
  176. "approval_prompt": {c.ApprovalPrompt},
  177. }.Encode()
  178. if url_.RawQuery == "" {
  179. url_.RawQuery = q
  180. } else {
  181. url_.RawQuery += "&" + q
  182. }
  183. return url_.String()
  184. }
  185. // Exchange takes a code and gets access Token from the remote server.
  186. func (t *Transport) Exchange(code string) (*Token, error) {
  187. if t.Config == nil {
  188. return nil, OAuthError{"Exchange", "no Config supplied"}
  189. }
  190. // If the transport or the cache already has a token, it is
  191. // passed to `updateToken` to preserve existing refresh token.
  192. tok := t.Token
  193. if tok == nil && t.TokenCache != nil {
  194. tok, _ = t.TokenCache.Token()
  195. }
  196. if tok == nil {
  197. tok = new(Token)
  198. }
  199. err := t.updateToken(tok, url.Values{
  200. "grant_type": {"authorization_code"},
  201. "redirect_uri": {t.RedirectURL},
  202. "scope": {t.Scope},
  203. "code": {code},
  204. })
  205. if err != nil {
  206. return nil, err
  207. }
  208. t.Token = tok
  209. if t.TokenCache != nil {
  210. return tok, t.TokenCache.PutToken(tok)
  211. }
  212. return tok, nil
  213. }
  214. // RoundTrip executes a single HTTP transaction using the Transport's
  215. // Token as authorization headers.
  216. //
  217. // This method will attempt to renew the Token if it has expired and may return
  218. // an error related to that Token renewal before attempting the client request.
  219. // If the Token cannot be renewed a non-nil os.Error value will be returned.
  220. // If the Token is invalid callers should expect HTTP-level errors,
  221. // as indicated by the Response's StatusCode.
  222. func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
  223. if t.Token == nil {
  224. if t.Config == nil {
  225. return nil, OAuthError{"RoundTrip", "no Config supplied"}
  226. }
  227. if t.TokenCache == nil {
  228. return nil, OAuthError{"RoundTrip", "no Token supplied"}
  229. }
  230. var err error
  231. t.Token, err = t.TokenCache.Token()
  232. if err != nil {
  233. return nil, err
  234. }
  235. }
  236. // Refresh the Token if it has expired.
  237. if t.Expired() {
  238. if err := t.Refresh(); err != nil {
  239. return nil, err
  240. }
  241. }
  242. // To set the Authorization header, we must make a copy of the Request
  243. // so that we don't modify the Request we were given.
  244. // This is required by the specification of http.RoundTripper.
  245. req = cloneRequest(req)
  246. req.Header.Set("Authorization", "Bearer "+t.AccessToken)
  247. // Make the HTTP request.
  248. return t.transport().RoundTrip(req)
  249. }
  250. // cloneRequest returns a clone of the provided *http.Request.
  251. // The clone is a shallow copy of the struct and its Header map.
  252. func cloneRequest(r *http.Request) *http.Request {
  253. // shallow copy of the struct
  254. r2 := new(http.Request)
  255. *r2 = *r
  256. // deep copy of the Header
  257. r2.Header = make(http.Header)
  258. for k, s := range r.Header {
  259. r2.Header[k] = s
  260. }
  261. return r2
  262. }
  263. // Refresh renews the Transport's AccessToken using its RefreshToken.
  264. func (t *Transport) Refresh() error {
  265. if t.Token == nil {
  266. return OAuthError{"Refresh", "no existing Token"}
  267. }
  268. if t.RefreshToken == "" {
  269. return OAuthError{"Refresh", "Token expired; no Refresh Token"}
  270. }
  271. if t.Config == nil {
  272. return OAuthError{"Refresh", "no Config supplied"}
  273. }
  274. err := t.updateToken(t.Token, url.Values{
  275. "grant_type": {"refresh_token"},
  276. "refresh_token": {t.RefreshToken},
  277. })
  278. if err != nil {
  279. return err
  280. }
  281. if t.TokenCache != nil {
  282. return t.TokenCache.PutToken(t.Token)
  283. }
  284. return nil
  285. }
  286. // AuthenticateClient gets an access Token using the client_credentials grant
  287. // type.
  288. func (t *Transport) AuthenticateClient() error {
  289. if t.Config == nil {
  290. return OAuthError{"Exchange", "no Config supplied"}
  291. }
  292. if t.Token == nil {
  293. t.Token = &Token{}
  294. }
  295. return t.updateToken(t.Token, url.Values{"grant_type": {"client_credentials"}})
  296. }
  297. func (t *Transport) updateToken(tok *Token, v url.Values) error {
  298. v.Set("client_id", t.ClientId)
  299. v.Set("client_secret", t.ClientSecret)
  300. client := &http.Client{Transport: t.transport()}
  301. req, err := http.NewRequest("POST", t.TokenURL, strings.NewReader(v.Encode()))
  302. if err != nil {
  303. return err
  304. }
  305. req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
  306. req.SetBasicAuth(t.ClientId, t.ClientSecret)
  307. r, err := client.Do(req)
  308. if err != nil {
  309. return err
  310. }
  311. defer r.Body.Close()
  312. if r.StatusCode != 200 {
  313. return OAuthError{"updateToken", r.Status}
  314. }
  315. var b struct {
  316. Access string `json:"access_token"`
  317. Refresh string `json:"refresh_token"`
  318. ExpiresIn time.Duration `json:"expires_in"`
  319. Id string `json:"id_token"`
  320. }
  321. content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
  322. switch content {
  323. case "application/x-www-form-urlencoded", "text/plain":
  324. body, err := ioutil.ReadAll(r.Body)
  325. if err != nil {
  326. return err
  327. }
  328. vals, err := url.ParseQuery(string(body))
  329. if err != nil {
  330. return err
  331. }
  332. b.Access = vals.Get("access_token")
  333. b.Refresh = vals.Get("refresh_token")
  334. b.ExpiresIn, _ = time.ParseDuration(vals.Get("expires_in") + "s")
  335. b.Id = vals.Get("id_token")
  336. default:
  337. if err = json.NewDecoder(r.Body).Decode(&b); err != nil {
  338. return err
  339. }
  340. // The JSON parser treats the unitless ExpiresIn like 'ns' instead of 's' as above,
  341. // so compensate here.
  342. b.ExpiresIn *= time.Second
  343. }
  344. tok.AccessToken = b.Access
  345. // Don't overwrite `RefreshToken` with an empty value
  346. if len(b.Refresh) > 0 {
  347. tok.RefreshToken = b.Refresh
  348. }
  349. if b.ExpiresIn == 0 {
  350. tok.Expiry = time.Time{}
  351. } else {
  352. tok.Expiry = time.Now().Add(b.ExpiresIn)
  353. }
  354. if b.Id != "" {
  355. if tok.Extra == nil {
  356. tok.Extra = make(map[string]string)
  357. }
  358. tok.Extra["id_token"] = b.Id
  359. }
  360. return nil
  361. }