Skip to content

Commit 4b83411

Browse files
nwidgerbradfitz
authored andcommitted
jwt: add Config.Audience field
Add an Audience field to jwt.Config which, if set, is used instead of TokenURL as the 'aud' claim in the generated JWT. This allows the jwt package to work with authorization servers that require the 'aud' claim and token endpoint URL to be different values. Fixes #369. Change-Id: I883aabece7f9b16ec726d5bfa98c1ec91876b651 GitHub-Last-Rev: fd73e4d GitHub-Pull-Request: #370 Reviewed-on: https://go-review.googlesource.com/c/162937 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
1 parent 3e8b2be commit 4b83411

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

jwt/jwt.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ type Config struct {
6161

6262
// Expires optionally specifies how long the token is valid for.
6363
Expires time.Duration
64+
65+
// Audience optionally specifies the intended audience of the
66+
// request. If empty, the value of TokenURL is used as the
67+
// intended audience.
68+
Audience string
6469
}
6570

6671
// TokenSource returns a JWT TokenSource using the configuration
@@ -105,6 +110,9 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
105110
if t := js.conf.Expires; t > 0 {
106111
claimSet.Exp = time.Now().Add(t).Unix()
107112
}
113+
if aud := js.conf.Audience; aud != "" {
114+
claimSet.Aud = aud
115+
}
108116
h := *defaultHeader
109117
h.KeyID = js.conf.PrivateKeyID
110118
payload, err := jws.Encode(&h, claimSet, pk)

jwt/jwt_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,80 @@ func TestJWTFetch_Assertion(t *testing.T) {
191191
}
192192
}
193193

194+
func TestJWTFetch_AssertionPayload(t *testing.T) {
195+
var assertion string
196+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
197+
r.ParseForm()
198+
assertion = r.Form.Get("assertion")
199+
200+
w.Header().Set("Content-Type", "application/json")
201+
w.Write([]byte(`{
202+
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
203+
"scope": "user",
204+
"token_type": "bearer",
205+
"expires_in": 3600
206+
}`))
207+
}))
208+
defer ts.Close()
209+
210+
for _, conf := range []*Config{
211+
{
212+
Email: "aaa1@xxx.com",
213+
PrivateKey: dummyPrivateKey,
214+
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
215+
TokenURL: ts.URL,
216+
},
217+
{
218+
Email: "aaa2@xxx.com",
219+
PrivateKey: dummyPrivateKey,
220+
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
221+
TokenURL: ts.URL,
222+
Audience: "https://example.com",
223+
},
224+
} {
225+
t.Run(conf.Email, func(t *testing.T) {
226+
_, err := conf.TokenSource(context.Background()).Token()
227+
if err != nil {
228+
t.Fatalf("Failed to fetch token: %v", err)
229+
}
230+
231+
parts := strings.Split(assertion, ".")
232+
if len(parts) != 3 {
233+
t.Fatalf("assertion = %q; want 3 parts", assertion)
234+
}
235+
gotjson, err := base64.RawURLEncoding.DecodeString(parts[1])
236+
if err != nil {
237+
t.Fatalf("invalid token payload; err = %v", err)
238+
}
239+
240+
claimSet := jws.ClaimSet{}
241+
if err := json.Unmarshal(gotjson, &claimSet); err != nil {
242+
t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err)
243+
}
244+
245+
if got, want := claimSet.Iss, conf.Email; got != want {
246+
t.Errorf("payload email = %q; want %q", got, want)
247+
}
248+
if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want {
249+
t.Errorf("payload scope = %q; want %q", got, want)
250+
}
251+
aud := conf.TokenURL
252+
if conf.Audience != "" {
253+
aud = conf.Audience
254+
}
255+
if got, want := claimSet.Aud, aud; got != want {
256+
t.Errorf("payload audience = %q; want %q", got, want)
257+
}
258+
if got, want := claimSet.Sub, conf.Subject; got != want {
259+
t.Errorf("payload subject = %q; want %q", got, want)
260+
}
261+
if got, want := claimSet.Prn, conf.Subject; got != want {
262+
t.Errorf("payload prn = %q; want %q", got, want)
263+
}
264+
})
265+
}
266+
}
267+
194268
func TestTokenRetrieveError(t *testing.T) {
195269
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
196270
w.Header().Set("Content-type", "application/json")

0 commit comments

Comments
 (0)