// Package scram implements the SCRAM-SHA-* SASL authentication mechanisms,
// RFC 7677 and RFC 5802.
//
// SCRAM-SHA-256 and SCRAM-SHA-1 allow a client to authenticate to a server using
// a password without handing the plaintext password over to the server. The
// client also verifies that the server knows (a derivative of) the password.
6package scram
7
// TODO: test with messages that contain extensions.
// TODO: add tests for the parser.
// TODO: figure out how invalid parameters etc. should be handled. Just abort? Perhaps mostly a problem for IMAP.
11
12import (
13 "bytes"
14 "crypto/hmac"
15 cryptorand "crypto/rand"
16 "encoding/base64"
17 "errors"
18 "fmt"
19 "hash"
20 "strings"
21
22 "golang.org/x/crypto/pbkdf2"
23 "golang.org/x/text/unicode/norm"
24)
25
// Errors at the SCRAM protocol level, named as in the server-error-value
// production of RFC 5802. They can be exchanged between client and server: a
// server sends one as "e=<name>", and the name can be mapped back to the
// matching value through scramErrors.
var (
	ErrInvalidEncoding                 = Error("invalid-encoding")
	ErrExtensionsNotSupported          = Error("extensions-not-supported")
	ErrInvalidProof                    = Error("invalid-proof")
	ErrChannelBindingsDontMatch        = Error("channel-bindings-dont-match")
	ErrServerDoesSupportChannelBinding = Error("server-does-support-channel-binding")
	ErrChannelBindingNotSupported      = Error("channel-binding-not-supported")
	ErrUnsupportedChannelBindingType   = Error("unsupported-channel-binding-type")
	ErrUnknownUser                     = Error("unknown-user")
	ErrNoResources                     = Error("no-resources")
	ErrOtherError                      = Error("other-error")
)

// scramErrors maps the wire name of each protocol-level error to its value.
var scramErrors = makeErrors()

// makeErrors builds the name-to-value lookup table for the protocol-level
// errors declared above.
func makeErrors() map[string]Error {
	known := [...]Error{
		ErrInvalidEncoding,
		ErrExtensionsNotSupported,
		ErrInvalidProof,
		ErrChannelBindingsDontMatch,
		ErrServerDoesSupportChannelBinding,
		ErrChannelBindingNotSupported,
		ErrUnsupportedChannelBindingType,
		ErrUnknownUser,
		ErrNoResources,
		ErrOtherError,
	}
	byName := make(map[string]Error, len(known))
	for _, e := range known {
		byName[e.Error()] = e
	}
	return byName
}

// Errors found while validating the parameters of a message. These are plain
// errors, not SCRAM protocol error names.
var (
	// ErrNorm: a parameter is not unicode normalized, e.g. if a client sends a
	// non-normalized username or authzid.
	ErrNorm = errors.New("parameter not unicode normalized")
	// ErrUnsafe: a parameter is too weak, e.g. a salt or nonce that is too
	// short, or too few iterations.
	ErrUnsafe = errors.New("unsafe parameter")
	// ErrProtocol: the other party broke the protocol, e.g. a server responded
	// with a nonce not prefixed by the client nonce.
	ErrProtocol = errors.New("protocol error")
)

// Error is a SCRAM protocol-level error. Its value is the error name as it
// appears on the wire after "e=".
type Error string

// Error implements the error interface by returning the error name.
func (e Error) Error() string { return string(e) }
73
// MakeRandom returns a 12-byte cryptographically random buffer for use as salt
// or as nonce.
//
// It panics if the secure random source fails: continuing without proper
// randomness would undermine the security of the exchange.
func MakeRandom() []byte {
	buf := make([]byte, 12)
	if _, err := cryptorand.Read(buf); err != nil {
		// Keep the underlying error so the failure can be diagnosed.
		panic(fmt.Errorf("generate random: %w", err))
	}
	return buf
}
84
85// SaltPassword returns a salted password.
86func SaltPassword(h func() hash.Hash, password string, salt []byte, iterations int) []byte {
87 password = norm.NFC.String(password)
88 return pbkdf2.Key([]byte(password), salt, iterations, h().Size(), h)
89}
90
// HMAC computes the MAC of msg keyed with key, using the hash constructor h
// (e.g. sha1.New or sha256.New), and returns it.
func HMAC(h func() hash.Hash, key []byte, msg string) []byte {
	m := hmac.New(h, key)
	// Writes to a hash.Hash never return an error.
	_, _ = m.Write([]byte(msg))
	return m.Sum(nil)
}
97
// xor combines b into a in place, setting a[i] ^= b[i] for every index of a.
// b must be at least as long as a.
func xor(a, b []byte) {
	for i := 0; i < len(a); i++ {
		a[i] = a[i] ^ b[i]
	}
}
103
// Server represents the server-side of a SCRAM-SHA-* authentication.
//
// A Server is created by NewServer from the client's first message. It holds
// the state needed to produce the server-first message (ServerFirst) and to
// verify the client-final message (Finish).
type Server struct {
	Authentication string // Username for authentication, "authc". Always set and non-empty.
	Authorization  string // If set, role of user to assume after authentication, "authz".

	h func() hash.Hash // sha1.New or sha256.New

	// Messages used in hash calculations. Joined with "," in this order they
	// form the AuthMessage that both sides sign.
	clientFirstBare         string
	serverFirst             string
	clientFinalWithoutProof string

	gs2header           string // gs2 header of the client-first message; the channel binding in the client-final message must match it.
	clientNonce         string // Client-part of the nonce.
	serverNonceOverride string // If set, server does not generate random nonce, but uses this. For tests with the test vector.
	nonce               string // Full client + server nonce.
}
121
122// NewServer returns a server given the first SCRAM message from a client.
123//
124// The sequence for data and calls on a server:
125//
126// - Read initial data from client, call NewServer (this call), then ServerFirst and write to the client.
127// - Read response from client, call Finish or FinishFinal and write the resulting string.
128func NewServer(h func() hash.Hash, clientFirst []byte) (server *Server, rerr error) {
129 p := newParser(clientFirst)
130 defer p.recover(&rerr)
131
132 server = &Server{h: h}
133
134 // ../rfc/5802:949 ../rfc/5802:910
135 gs2cbindFlag := p.xbyte()
136 switch gs2cbindFlag {
137 case 'n', 'y':
138 case 'p':
139 p.xerrorf("gs2 header with p: %w", ErrChannelBindingNotSupported)
140 }
141 p.xtake(",")
142 if !p.take(",") {
143 server.Authorization = p.xauthzid()
144 if norm.NFC.String(server.Authorization) != server.Authorization {
145 return nil, fmt.Errorf("%w: authzid", ErrNorm)
146 }
147 p.xtake(",")
148 }
149 server.gs2header = p.s[:p.o]
150 server.clientFirstBare = p.s[p.o:]
151
152 // ../rfc/5802:945
153 if p.take("m=") {
154 p.xerrorf("unexpected mandatory extension: %w", ErrExtensionsNotSupported)
155 }
156 server.Authentication = p.xusername()
157 if norm.NFC.String(server.Authentication) != server.Authentication {
158 return nil, fmt.Errorf("%w: username", ErrNorm)
159 }
160 p.xtake(",")
161 server.clientNonce = p.xnonce()
162 if len(server.clientNonce) < 8 {
163 return nil, fmt.Errorf("%w: client nonce too short", ErrUnsafe)
164 }
165 // Extensions, we don't recognize them.
166 for p.take(",") {
167 p.xattrval()
168 }
169 p.xempty()
170 return server, nil
171}
172
173// ServerFirst returns the string to send back to the client. To be called after NewServer.
174func (s *Server) ServerFirst(iterations int, salt []byte) (string, error) {
175 // ../rfc/5802:959
176 serverNonce := s.serverNonceOverride
177 if serverNonce == "" {
178 serverNonce = base64.StdEncoding.EncodeToString(MakeRandom())
179 }
180 s.nonce = s.clientNonce + serverNonce
181 s.serverFirst = fmt.Sprintf("r=%s,s=%s,i=%d", s.nonce, base64.StdEncoding.EncodeToString(salt), iterations)
182 return s.serverFirst, nil
183}
184
185// Finish takes the final client message, and the salted password (probably
186// from server storage), verifies the client, and returns a message to return
187// to the client. If err is nil, authentication was successful. If the
188// authorization requested is not acceptable, the server should call
189// FinishError instead.
190func (s *Server) Finish(clientFinal []byte, saltedPassword []byte) (serverFinal string, rerr error) {
191 p := newParser(clientFinal)
192 defer p.recover(&rerr)
193
194 cbind := p.xchannelBinding()
195 if cbind != s.gs2header {
196 return "e=" + string(ErrChannelBindingsDontMatch), ErrChannelBindingsDontMatch
197 }
198 p.xtake(",")
199 nonce := p.xnonce()
200 if nonce != s.nonce {
201 return "e=" + string(ErrInvalidProof), ErrInvalidProof
202 }
203 for !p.peek(",p=") {
204 p.xtake(",")
205 p.xattrval() // Ignored.
206 }
207 s.clientFinalWithoutProof = p.s[:p.o]
208 p.xtake(",")
209 proof := p.xproof()
210 p.xempty()
211
212 msg := s.clientFirstBare + "," + s.serverFirst + "," + s.clientFinalWithoutProof
213
214 clientKey := HMAC(s.h, saltedPassword, "Client Key")
215 h := s.h()
216 h.Write(clientKey)
217 storedKey := h.Sum(nil)
218
219 clientSig := HMAC(s.h, storedKey, msg)
220 xor(clientSig, clientKey) // Now clientProof.
221 if !bytes.Equal(clientSig, proof) {
222 return "e=" + string(ErrInvalidProof), ErrInvalidProof
223 }
224
225 serverKey := HMAC(s.h, saltedPassword, "Server Key")
226 serverSig := HMAC(s.h, serverKey, msg)
227 return fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(serverSig)), nil
228}
229
230// FinishError returns an error message to write to the client for the final
231// server message.
232func (s *Server) FinishError(err Error) string {
233 return "e=" + string(err)
234}
235
// Client represents the client-side of a SCRAM-SHA-* authentication.
//
// Create one with NewClient, then call ClientFirst, ServerFirst and
// ServerFinal in that order.
type Client struct {
	authc string // Username for authentication, NFC-normalized by NewClient.
	authz string // If non-empty, role to assume after authentication; NFC-normalized by NewClient.

	h func() hash.Hash // sha1.New or sha256.New

	// Messages used in hash calculations.
	clientFirstBare         string
	serverFirst             string
	clientFinalWithoutProof string
	authMessage             string // clientFirstBare + "," + serverFirst + "," + clientFinalWithoutProof.

	gs2header      string // Sent in the first message, echoed (base64) as channel binding data in the final message.
	clientNonce    string // Client-part of the nonce; generated by ClientFirst if empty.
	nonce          string // Full client + server nonce.
	saltedPassword []byte // Derived in ServerFirst, used by ServerFinal to verify the server signature.
}
254
255// NewClient returns a client for authentication authc, optionally for
256// authorization with role authz, for the hash (sha1.New or sha256.New).
257//
258// The sequence for data and calls on a client:
259//
260// - ClientFirst, write result to server.
261// - Read response from server, feed to ServerFirst, write response to server.
262// - Read response from server, feed to ServerFinal.
263func NewClient(h func() hash.Hash, authc, authz string) *Client {
264 authc = norm.NFC.String(authc)
265 authz = norm.NFC.String(authz)
266 return &Client{authc: authc, authz: authz, h: h}
267}
268
269// ClientFirst returns the first client message to write to the server.
270// No channel binding is done/supported.
271// A random nonce is generated.
272func (c *Client) ClientFirst() (clientFirst string, rerr error) {
273 c.gs2header = fmt.Sprintf("n,%s,", saslname(c.authz))
274 if c.clientNonce == "" {
275 c.clientNonce = base64.StdEncoding.EncodeToString(MakeRandom())
276 }
277 c.clientFirstBare = fmt.Sprintf("n=%s,r=%s", saslname(c.authc), c.clientNonce)
278 return c.gs2header + c.clientFirstBare, nil
279}
280
281// ServerFirst processes the first response message from the server. The
282// provided nonce, salt and iterations are checked. If valid, a final client
283// message is calculated and returned. This message must be written to the
284// server. It includes proof that the client knows the password.
285func (c *Client) ServerFirst(serverFirst []byte, password string) (clientFinal string, rerr error) {
286 c.serverFirst = string(serverFirst)
287 p := newParser(serverFirst)
288 defer p.recover(&rerr)
289
290 // ../rfc/5802:959
291 if p.take("m=") {
292 p.xerrorf("unsupported mandatory extension: %w", ErrExtensionsNotSupported)
293 }
294
295 c.nonce = p.xnonce()
296 p.xtake(",")
297 salt := p.xsalt()
298 p.xtake(",")
299 iterations := p.xiterations()
300 // We ignore extensions that we don't know about.
301 for p.take(",") {
302 p.xattrval()
303 }
304 p.xempty()
305
306 if !strings.HasPrefix(c.nonce, c.clientNonce) {
307 return "", fmt.Errorf("%w: server dropped our nonce", ErrProtocol)
308 }
309 if len(c.nonce)-len(c.clientNonce) < 8 {
310 return "", fmt.Errorf("%w: server nonce too short", ErrUnsafe)
311 }
312 if len(salt) < 8 {
313 return "", fmt.Errorf("%w: salt too short", ErrUnsafe)
314 }
315 if iterations < 2048 {
316 return "", fmt.Errorf("%w: too few iterations", ErrUnsafe)
317 }
318
319 c.clientFinalWithoutProof = fmt.Sprintf("c=%s,r=%s", base64.StdEncoding.EncodeToString([]byte(c.gs2header)), c.nonce)
320
321 c.authMessage = c.clientFirstBare + "," + c.serverFirst + "," + c.clientFinalWithoutProof
322
323 c.saltedPassword = SaltPassword(c.h, password, salt, iterations)
324 clientKey := HMAC(c.h, c.saltedPassword, "Client Key")
325 h := c.h()
326 h.Write(clientKey)
327 storedKey := h.Sum(nil)
328 clientSig := HMAC(c.h, storedKey, c.authMessage)
329 xor(clientSig, clientKey) // Now clientProof.
330 clientProof := clientSig
331
332 r := c.clientFinalWithoutProof + ",p=" + base64.StdEncoding.EncodeToString(clientProof)
333 return r, nil
334}
335
336// ServerFinal processes the final message from the server, verifying that the
337// server knows the password.
338func (c *Client) ServerFinal(serverFinal []byte) (rerr error) {
339 p := newParser(serverFinal)
340 defer p.recover(&rerr)
341
342 if p.take("e=") {
343 errstr := p.xvalue()
344 var err error = scramErrors[errstr]
345 if err == Error("") {
346 err = errors.New(errstr)
347 }
348 return fmt.Errorf("error from server: %w", err)
349 }
350 p.xtake("v=")
351 verifier := p.xbase64()
352
353 serverKey := HMAC(c.h, c.saltedPassword, "Server Key")
354 serverSig := HMAC(c.h, serverKey, c.authMessage)
355 if !bytes.Equal(verifier, serverSig) {
356 return fmt.Errorf("incorrect server signature")
357 }
358 return nil
359}
360
// saslname encodes s as a SASL name for SCRAM messages (RFC 5802, section
// 5.1). "," becomes "=2C" and "=" becomes "=3D", so the value cannot be
// confused with attribute separators. Invalid UTF-8 bytes come out as U+FFFD,
// because ranging over a string yields utf8.RuneError for them.
func saslname(s string) string {
	var b strings.Builder
	// Names rarely contain special characters, so the output is usually the
	// same size as the input. A Builder avoids the quadratic cost of += in a loop.
	b.Grow(len(s))
	for _, c := range s {
		switch c {
		case ',':
			b.WriteString("=2C")
		case '=':
			b.WriteString("=3D")
		default:
			b.WriteRune(c)
		}
	}
	return b.String()
}
375