5	cryptorand "crypto/rand"
 
19func base64Decode(s string) []byte {
 
20	buf, err := base64.StdEncoding.DecodeString(s)
 
27func tcheck(t *testing.T, err error, msg string) {
 
30		t.Fatalf("%s: %s", msg, err)
 
34func TestSCRAMSHA1Server(t *testing.T) {
 
36	salt := base64Decode("QSXCR+Q6sek8bf92")
 
37	saltedPassword := SaltPassword(sha1.New, "pencil", salt, 4096)
 
39	server, err := NewServer(sha1.New, []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL"), nil, false)
 
40	server.serverNonceOverride = "3rfcNHYJY1ZVvWVs7j"
 
41	tcheck(t, err, "newserver")
 
42	resp, err := server.ServerFirst(4096, salt)
 
43	tcheck(t, err, "server first")
 
44	if resp != "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096" {
 
45		t.Fatalf("bad server first")
 
47	serverFinal, err := server.Finish([]byte("c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts="), saltedPassword)
 
48	tcheck(t, err, "finish")
 
49	if serverFinal != "v=rmF9pqV8S7suAoZWja4dJRkFsKQ=" {
 
50		t.Fatalf("bad server final")
 
54func TestSCRAMSHA256Server(t *testing.T) {
 
56	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
 
57	saltedPassword := SaltPassword(sha256.New, "pencil", salt, 4096)
 
59	server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
 
60	server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
 
61	tcheck(t, err, "newserver")
 
62	resp, err := server.ServerFirst(4096, salt)
 
63	tcheck(t, err, "server first")
 
64	if resp != "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" {
 
65		t.Fatalf("bad server first")
 
67	serverFinal, err := server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
 
68	tcheck(t, err, "finish")
 
69	if serverFinal != "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=" {
 
70		t.Fatalf("bad server final")
 
74// Bad attempt with wrong password.
 
75func TestScramServerBadPassword(t *testing.T) {
 
76	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
 
77	saltedPassword := SaltPassword(sha256.New, "marker", salt, 4096)
 
79	server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
 
80	server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
 
81	tcheck(t, err, "newserver")
 
82	_, err = server.ServerFirst(4096, salt)
 
83	tcheck(t, err, "server first")
 
84	_, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
 
85	if !errors.Is(err, ErrInvalidProof) {
 
86		t.Fatalf("got %v, expected ErrInvalidProof", err)
 
90// Bad attempt with different number of rounds.
 
91func TestScramServerBadIterations(t *testing.T) {
 
92	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
 
93	saltedPassword := SaltPassword(sha256.New, "pencil", salt, 2048)
 
95	server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
 
96	server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
 
97	tcheck(t, err, "newserver")
 
98	_, err = server.ServerFirst(4096, salt)
 
99	tcheck(t, err, "server first")
 
100	_, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
 
101	if !errors.Is(err, ErrInvalidProof) {
 
102		t.Fatalf("got %v, expected ErrInvalidProof", err)
 
106// Another attempt but with a randomly different nonce.
 
107func TestScramServerBad(t *testing.T) {
 
108	salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
 
109	saltedPassword := SaltPassword(sha256.New, "pencil", salt, 4096)
 
111	server, err := NewServer(sha256.New, []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"), nil, false)
 
112	tcheck(t, err, "newserver")
 
113	_, err = server.ServerFirst(4096, salt)
 
114	tcheck(t, err, "server first")
 
115	_, err = server.Finish([]byte("c=biws,r="+server.nonce+",p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
 
116	if !errors.Is(err, ErrInvalidProof) {
 
117		t.Fatalf("got %v, expected ErrInvalidProof", err)
 
121func TestScramClient(t *testing.T) {
 
122	c := NewClient(sha256.New, "user", "", false, nil)
 
123	c.clientNonce = "rOprNGfwEbeRWgbNEkqO"
 
124	clientFirst, err := c.ClientFirst()
 
125	tcheck(t, err, "ClientFirst")
 
126	if clientFirst != "n,,n=user,r=rOprNGfwEbeRWgbNEkqO" {
 
127		t.Fatalf("bad clientFirst")
 
129	clientFinal, err := c.ServerFirst([]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"), "pencil")
 
130	tcheck(t, err, "ServerFirst")
 
131	if clientFinal != "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=" {
 
132		t.Fatalf("bad clientFinal")
 
134	err = c.ServerFinal([]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="))
 
135	tcheck(t, err, "ServerFinal")
 
138func TestScram(t *testing.T) {
 
139	runHash := func(h func() hash.Hash, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string, noServerPlus bool, clientcs, servercs *tls.ConnectionState) {
 
144			if x == nil || x == "" {
 
150		// check err is either nil or the expected error. if the expected error, panic to abort the authentication session.
 
151		xerr := func(err error, msg string) {
 
153			if err != nil && !errors.Is(err, expErr) {
 
154				t.Fatalf("%s: got %v, expected %v", msg, err, expErr)
 
157				panic("") // Abort test.
 
162		saltedPassword := SaltPassword(h, password, salt, iterations)
 
164		client := NewClient(h, username, "", noServerPlus, clientcs)
 
165		client.clientNonce = clientNonce
 
166		clientFirst, err := client.ClientFirst()
 
167		xerr(err, "client.ClientFirst")
 
169		server, err := NewServer(h, []byte(clientFirst), servercs, servercs != nil)
 
170		xerr(err, "NewServer")
 
171		server.serverNonceOverride = serverNonce
 
173		serverFirst, err := server.ServerFirst(iterations, salt)
 
174		xerr(err, "server.ServerFirst")
 
176		clientFinal, err := client.ServerFirst([]byte(serverFirst), password)
 
177		xerr(err, "client.ServerFirst")
 
179		serverFinal, err := server.Finish([]byte(clientFinal), saltedPassword)
 
180		xerr(err, "server.Finish")
 
182		err = client.ServerFinal([]byte(serverFinal))
 
183		xerr(err, "client.ServerFinal")
 
186			t.Fatalf("got no error, expected %v", expErr)
 
190	makeState := func(maxTLSVersion uint16) (tls.ConnectionState, tls.ConnectionState) {
 
191		client, server := net.Pipe()
 
194		tlsClient := tls.Client(client, &tls.Config{
 
195			InsecureSkipVerify: true,
 
196			MaxVersion:         maxTLSVersion,
 
198		tlsServer := tls.Server(server, &tls.Config{
 
199			Certificates: []tls.Certificate{fakeCert(t, "mox.example", false)},
 
200			MaxVersion:   maxTLSVersion,
 
202		errc := make(chan error, 1)
 
204			errc <- tlsServer.Handshake()
 
206		err := tlsClient.Handshake()
 
207		tcheck(t, err, "tls handshake")
 
209		tcheck(t, err, "server tls handshake")
 
210		clientcs := tlsClient.ConnectionState()
 
211		servercs := tlsServer.ConnectionState()
 
213		return clientcs, servercs
 
216	runPlus := func(maxTLSVersion uint16, expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
 
220		clientcs, servercs := makeState(maxTLSVersion)
 
221		runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
 
222		runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, &clientcs, &servercs)
 
225	run := func(expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
 
229		runHash(sha1.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
 
230		runHash(sha256.New, expErr, username, authzid, password, iterations, clientNonce, serverNonce, false, nil, nil)
 
232		// Check with both TLS 1.2 for "tls-unique", and latest TLS for "tls-exporter".
 
233		runPlus(tls.VersionTLS12, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
 
234		runPlus(0, expErr, username, authzid, password, iterations, clientNonce, serverNonce)
 
237	run(nil, "user", "", "pencil", 4096, "", "")
 
238	run(nil, "mjl@mox.example", "", "testtest", 4096, "", "")
 
239	run(nil, "mjl@mox.example", "", "short", 4096, "", "")
 
240	run(nil, "mjl@mox.example", "", "short", 2048, "", "")
 
241	run(nil, "mjl@mox.example", "mjl@mox.example", "testtest", 4096, "", "")
 
242	run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
 
243	run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
 
244	run(ErrUnsafe, "user", "", "pencil", 1, "", "")                // Few iterations.
 
245	run(ErrUnsafe, "user", "", "pencil", 2048, "short", "")        // Short client nonce.
 
246	run(ErrUnsafe, "user", "", "pencil", 2048, "test1234", "test") // Server added too few random data.
 
248	// Test mechanism downgrade attacks are detected.
 
249	runHash(sha1.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
 
250	runHash(sha256.New, ErrServerDoesSupportChannelBinding, "user", "", "pencil", 4096, "", "", true, nil, nil)
 
252	// Test channel binding, detecting MitM attacks.
 
253	runChannelBind := func(maxTLSVersion uint16) {
 
256		clientcs0, _ := makeState(maxTLSVersion)
 
257		_, servercs1 := makeState(maxTLSVersion)
 
258		runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
 
259		runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, &servercs1)
 
261		// Client thinks it is on a TLS connection and server is not.
 
262		runHash(sha1.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
 
263		runHash(sha256.New, ErrChannelBindingsDontMatch, "user", "", "pencil", 4096, "", "", false, &clientcs0, nil)
 
267	runChannelBind(tls.VersionTLS12)
 
270// Just a cert that appears valid.
 
271func fakeCert(t *testing.T, name string, expired bool) tls.Certificate {
 
272	notAfter := time.Now()
 
274		notAfter = notAfter.Add(-time.Hour)
 
276		notAfter = notAfter.Add(time.Hour)
 
279	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
 
280	template := &x509.Certificate{
 
281		SerialNumber: big.NewInt(1), // Required field...
 
282		DNSNames:     []string{name},
 
283		NotBefore:    time.Now().Add(-time.Hour),
 
286	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
 
288		t.Fatalf("making certificate: %s", err)
 
290	cert, err := x509.ParseCertificate(localCertBuf)
 
292		t.Fatalf("parsing generated certificate: %s", err)
 
294	c := tls.Certificate{
 
295		Certificate: [][]byte{localCertBuf},