1package scram
2
3import (
4 "crypto/ed25519"
5 cryptorand "crypto/rand"
6 "crypto/sha1"
7 "crypto/sha256"
8 "crypto/tls"
9 "crypto/x509"
10 "encoding/base64"
11 "errors"
12 "hash"
13 "math/big"
14 "net"
15 "testing"
16 "time"
17)
18
// base64Decode returns the standard (padded) base64 decoding of s. It panics
// on malformed input; it is only meant for the fixed test vectors in this file.
func base64Decode(s string) []byte {
	if decoded, decodeErr := base64.StdEncoding.DecodeString(s); decodeErr == nil {
		return decoded
	}
	panic("bad base64")
}
26
// tcheck stops the test with a fatal failure, formatted as "msg: err", when
// err is non-nil. It is marked as a helper so failures point at the caller.
func tcheck(t *testing.T, err error, msg string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", msg, err)
}
33
34func TestSCRAMSHA1Server(t *testing.T) {
35 // Test vector from ../rfc/5802:496
36 salt := base64Decode("QSXCR+Q6sek8bf92")
37 saltedPassword, err := SaltPassword(sha1.New, "pencil", salt, 4096)
38 tcheck(t, err, "saltpassword")
39
40 server, err := NewServer(sha1.New, []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL"), nil, false)
41 server.serverNonceOverride = "3rfcNHYJY1ZVvWVs7j"
42 tcheck(t, err, "newserver")
43 resp, err := server.ServerFirst(4096, salt)
44 tcheck(t, err, "server first")
45 if resp != "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096" {
46 t.Fatalf("bad server first")
47 }
48 serverFinal, err := server.Finish([]byte("c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts="), saltedPassword)
49 tcheck(t, err, "finish")
50 if serverFinal != "v=rmF9pqV8S7suAoZWja4dJRkFsKQ=" {
51 t.Fatalf("bad server final")
52 }
53}
54
55func TestSCRAMSHA256Server(t *testing.T) {
56 // Test vector from ../rfc/7677:122
57 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
58 saltedPassword, err := SaltPassword(sha256.New, "pencil", salt, 4096)
59 tcheck(t, err, "saltpassword")
60
61 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
62 server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
63 tcheck(t, err, "newserver")
64 resp, err := server.ServerFirst(4096, salt)
65 tcheck(t, err, "server first")
66 if resp != "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" {
67 t.Fatalf("bad server first")
68 }
69 serverFinal, err := server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
70 tcheck(t, err, "finish")
71 if serverFinal != "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=" {
72 t.Fatalf("bad server final")
73 }
74}
75
76// Bad attempt with wrong password.
77func TestScramServerBadPassword(t *testing.T) {
78 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
79 saltedPassword, err := SaltPassword(sha256.New, "marker", salt, 4096)
80 tcheck(t, err, "saltpassword")
81
82 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
83 server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
84 tcheck(t, err, "newserver")
85 _, err = server.ServerFirst(4096, salt)
86 tcheck(t, err, "server first")
87 _, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
88 if !errors.Is(err, ErrInvalidProof) {
89 t.Fatalf("got %v, expected ErrInvalidProof", err)
90 }
91}
92
93// Bad attempt with different number of rounds.
94func TestScramServerBadIterations(t *testing.T) {
95 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
96 saltedPassword, err := SaltPassword(sha256.New, "pencil", salt, 2048)
97 tcheck(t, err, "saltpassword")
98
99 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
100 server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
101 tcheck(t, err, "newserver")
102 _, err = server.ServerFirst(4096, salt)
103 tcheck(t, err, "server first")
104 _, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
105 if !errors.Is(err, ErrInvalidProof) {
106 t.Fatalf("got %v, expected ErrInvalidProof", err)
107 }
108}
109
110// Another attempt but with a randomly different nonce.
111func TestScramServerBad(t *testing.T) {
112 salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
113 saltedPassword, err := SaltPassword(sha256.New, "pencil", salt, 4096)
114 tcheck(t, err, "saltpassword")
115
116 server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
117 tcheck(t, err, "newserver")
118 _, err = server.ServerFirst(4096, salt)
119 tcheck(t, err, "server first")
120 _, err = server.Finish([]byte("c=biws,r="+server.nonce+",p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
121 if !errors.Is(err, ErrInvalidProof) {
122 t.Fatalf("got %v, expected ErrInvalidProof", err)
123 }
124}
125
126func TestScramClient(t *testing.T) {
127 c := NewClient(sha256.New, "user", "", false, nil)
128 c.clientNonce = "rOprNGfwEbeRWgbNEkqO"
129 clientFirst, err := c.ClientFirst()
130 tcheck(t, err, "ClientFirst")
131 if clientFirst != "n,,n=user,r=rOprNGfwEbeRWgbNEkqO" {
132 t.Fatalf("bad clientFirst")
133 }
134 clientFinal, err := c.ServerFirst([]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"), "pencil")
135 tcheck(t, err, "ServerFirst")
136 if clientFinal != "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=" {
137 t.Fatalf("bad clientFinal")
138 }
139 err = c.ServerFinal([]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="))
140 tcheck(t, err, "ServerFinal")
141}
142
143func TestScram(t *testing.T) {
144 runHash := func(h func() hash.Hash, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string, noServerPlus bool, clientcs, servercs *tls.ConnectionState) {
145 t.Helper()
146
147 defer func() {
148 x := recover()
149 if x == nil || x == "" {
150 return
151 }
152 panic(x)
153 }()
154
155 // check err is either nil or the expected error. if the expected error, panic to abort the authentication session.
156 xerr := func(err error, msg string) {
157 t.Helper()
158 if err != nil && !errors.Is(err, expErr) {
159 t.Fatalf("%s: got %v, expected %v", msg, err, expErr)
160 }
161 if err != nil {
162 panic("") // Abort test.
163 }
164 }
165
166 salt := MakeRandom()
167 saltedPassword, err := SaltPassword(h, password, salt, iterations)
168 tcheck(t, err, "saltpassword")
169
170 client := NewClient(h, username, "", noServerPlus, clientcs)
171 client.clientNonce = clientNonce
172 clientFirst, err := client.ClientFirst()
173 xerr(err, "client.ClientFirst")
174
175 server, err := NewServer(h, []byte(clientFirst), servercs, servercs != nil)
176 xerr(err, "NewServer")
177 server.serverNonceOverride = serverNonce
178
179 serverFirst, err := server.ServerFirst(iterations, salt)
180 xerr(err, "server.ServerFirst")
181
182 clientFinal, err := client.ServerFirst([]byte(serverFirst), password)
183 xerr(err, "client.ServerFirst")
184
185 serverFinal, err := server.Finish([]byte(clientFinal), saltedPassword)
186 xerr(err, "server.Finish")
187
188 err = client.ServerFinal([]byte(serverFinal))
189 xerr(err, "client.ServerFinal")
190
191 if expErr != nil {
192 t.Fatalf("got no error, expected %v", expErr)
193 }
194 }
195
196 makeState := func(maxTLSVersion uint16) (tls.ConnectionState, tls.ConnectionState) {
197 client, server := net.Pipe()
198 defer client.Close()
199 defer server.Close()
200 tlsClient := tls.Client(client, &tls.Config{
201 InsecureSkipVerify: true,
202 MaxVersion: maxTLSVersion,
203 })
204 tlsServer := tls.Server(server, &tls.Config{
205 Certificates: []tls.Certificate{fakeCert(t, "mox.example", false)},
206 MaxVersion: maxTLSVersion,
207 })
208 errc := make(chan error, 1)
209 go func() {
210 errc <- tlsServer.Handshake()
211 }()
212 err := tlsClient.Handshake()
213 tcheck(t, err, "tls handshake")
214 err = <-errc
215 tcheck(t, err, "server tls handshake")
216 clientcs := tlsClient.ConnectionState()
217 servercs := tlsServer.ConnectionState()
218
219 return clientcs, servercs
220 }
221
222 runPlus := func(maxTLSVersion uint16, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
223 t.Helper()
224
225 // PLUS variants.
226 clientcs, servercs := makeState(maxTLSVersion)
227 runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
228 runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
229 }
230
231 run := func(expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
232 t.Helper()
233
234 // Bare variants
235 runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
236 runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
237
238 // Check with both TLS 1.2 for "tls-unique", and latest TLS for "tls-exporter".
239 runPlus(tls.VersionTLS12, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
240 runPlus(0, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
241 }
242
243 run(nil, "user", "", "pencil", 4096, "", "")
244 run(nil, "mjl@mox.example", "", "testtest", 4096, "", "")
245 run(nil, "mjl@mox.example", "", "short", 4096, "", "")
246 run(nil, "mjl@mox.example", "", "short", 2048, "", "")
247 run(nil, "mjl@mox.example", "mjl@mox.example", "testtest", 4096, "", "")
248 run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
249 run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
250 run(ErrUnsafe, "user", "", "pencil", 1, "", "") // Few iterations.
251 run(ErrUnsafe, "user", "", "pencil", 2048, "short", "") // Short client nonce.
252 run(ErrUnsafe, "user", "", "pencil", 2048, "test1234", "test") // Server added too few random data.
253
254 // Test mechanism downgrade attacks are detected.
255 runHash(sha1.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
256 runHash(sha256.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
257
258 // Test channel binding, detecting MitM attacks.
259 runChannelBind := func(maxTLSVersion uint16) {
260 t.Helper()
261
262 clientcs0, _ := makeState(maxTLSVersion)
263 _, servercs1 := makeState(maxTLSVersion)
264 runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
265 runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
266
267 // Client thinks it is on a TLS connection and server is not.
268 runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
269 runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
270 }
271
272 runChannelBind(0)
273 runChannelBind(tls.VersionTLS12)
274}
275
// fakeCert returns a self-signed TLS certificate for name. Its validity starts
// an hour ago. It ends an hour from now, or, if expired is set, an hour ago, so
// the certificate is already expired. The key is derived from an all-zero
// ed25519 seed, so the certificate is only suitable for tests.
func fakeCert(t *testing.T, name string, expired bool) tls.Certificate {
	t.Helper()

	// Use a single timestamp so NotBefore and NotAfter share the same base.
	now := time.Now()
	notAfter := now.Add(time.Hour)
	if expired {
		notAfter = now.Add(-time.Hour)
	}

	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field...
		DNSNames:     []string{name},
		NotBefore:    now.Add(-time.Hour),
		NotAfter:     notAfter,
	}
	// Self-signed: the template is both the certificate and its issuer.
	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	cert, err := x509.ParseCertificate(localCertBuf)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{localCertBuf},
		PrivateKey:  privKey,
		Leaf:        cert,
	}
}
307