16	"golang.org/x/net/websocket"
 
18	"github.com/mjl-/mox/mox-"
 
// tcheck aborts the running test when err is non-nil, reporting msg
// followed by the error text. A nil err is a no-op. It is marked as a
// test helper so a failure is attributed to the caller's line.
func tcheck(t *testing.T, err error, msg string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", msg, err)
}
 
28func TestWebserver(t *testing.T) {
 
29	os.RemoveAll("../testdata/webserver/data")
 
30	mox.ConfigStaticPath = filepath.FromSlash("../testdata/webserver/mox.conf")
 
31	mox.ConfigDynamicPath = filepath.Join(filepath.Dir(mox.ConfigStaticPath), "domains.conf")
 
32	mox.MustLoadConfig(true, false)
 
34	loadStaticGzipCache(mox.DataDirPath("tmp/httpstaticcompresscache"), 1024*1024)
 
36	srv := &serve{Webserver: true}
 
38	test := func(method, target string, reqhdrs map[string]string, expCode int, expContent string, expHeaders map[string]string) {
 
41		req := httptest.NewRequest(method, target, nil)
 
42		for k, v := range reqhdrs {
 
45		rw := httptest.NewRecorder()
 
46		rw.Body = &bytes.Buffer{}
 
47		srv.ServeHTTP(rw, req)
 
49		if resp.StatusCode != expCode {
 
50			t.Fatalf("got statuscode %d, expected %d", resp.StatusCode, expCode)
 
55				t.Fatalf("got response data %q, expected %q", s, expContent)
 
58		for k, v := range expHeaders {
 
59			if xv := resp.Header.Get(k); xv != v {
 
60				t.Fatalf("got %q for header %q, expected %q", xv, k, v)
 
65	test("GET", "http://redir.mox.example", nil, http.StatusPermanentRedirect, "", map[string]string{"Location": "https://mox.example/"})
 
67	// http to https redirect, and stay on https afterwards without redirect loop.
 
68	test("GET", "http://schemeredir.example", nil, http.StatusPermanentRedirect, "", map[string]string{"Location": "https://schemeredir.example/"})
 
69	test("GET", "https://schemeredir.example", nil, http.StatusNotFound, "", nil)
 
71	accgzip := map[string]string{"Accept-Encoding": "gzip"}
 
72	test("GET", "http://mox.example/static/", accgzip, http.StatusOK, "", map[string]string{"X-Test": "mox", "Content-Encoding": "gzip"})       // index.html
 
73	test("GET", "http://mox.example/static/dir/hi.txt", accgzip, http.StatusOK, "", map[string]string{"X-Test": "mox", "Content-Encoding": ""}) // too small to compress
 
74	test("GET", "http://mox.example/static/dir/", accgzip, http.StatusOK, "", map[string]string{"X-Test": "mox", "Content-Encoding": "gzip"})   // listing
 
75	test("GET", "http://mox.example/static/dir", accgzip, http.StatusTemporaryRedirect, "", map[string]string{"Location": "/static/dir/"})      // redirect to dir
 
76	test("GET", "http://mox.example/static/bogus", accgzip, http.StatusNotFound, "", map[string]string{"Content-Encoding": ""})
 
78	test("GET", "http://mox.example/nolist/", nil, http.StatusOK, "", nil)            // index.html
 
79	test("GET", "http://mox.example/nolist/dir/", nil, http.StatusForbidden, "", nil) // no listing
 
81	test("GET", "http://mox.example/tls/", nil, http.StatusPermanentRedirect, "", map[string]string{"Location": "https://mox.example/tls/"}) // redirect to tls
 
83	test("GET", "http://mox.example/baseurl/x?y=2", nil, http.StatusPermanentRedirect, "", map[string]string{"Location": "https://tls.mox.example/baseurl/x?q=1&y=2#fragment"})
 
84	test("GET", "http://mox.example/pathonly/old/x?q=2", nil, http.StatusTemporaryRedirect, "", map[string]string{"Location": "http://mox.example/pathonly/new/x?q=2"})
 
85	test("GET", "http://mox.example/baseurlpath/old/x?y=2", nil, http.StatusPermanentRedirect, "", map[string]string{"Location": "//other.mox.example/baseurlpath/new/x?q=1&y=2#fragment"})
 
87	test("GET", "http://mox.example/strip/x", nil, http.StatusBadGateway, "", nil)   // no server yet
 
88	test("GET", "http://mox.example/nostrip/x", nil, http.StatusBadGateway, "", nil) // no server yet
 
90	badForwarded := map[string]string{
 
92		"X-Forwarded-For":   "bad",
 
93		"X-Forwarded-Proto": "bad",
 
94		"X-Forwarded-Host":  "bad",
 
95		"X-Forwarded-Ext":   "bad",
 
98	// Server that echoes path, and forwarded request headers.
 
99	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
100		for k, v := range badForwarded {
 
101			if r.Header.Get(k) == v {
 
102				w.WriteHeader(http.StatusInternalServerError)
 
107		for k, vl := range r.Header {
 
108			if k == "Forwarded" || k == "X-Forwarded" || strings.HasPrefix(k, "X-Forwarded-") {
 
112		w.Write([]byte(r.URL.Path))
 
116	serverURL, err := url.Parse(server.URL)
 
118		t.Fatalf("parsing url: %v", err)
 
120	serverURL.Path = "/a"
 
122	// warning: it is not normally allowed to access the dynamic config without lock. don't propagate accesses like this!
 
123	mox.Conf.Dynamic.WebHandlers[len(mox.Conf.Dynamic.WebHandlers)-2].WebForward.TargetURL = serverURL
 
124	mox.Conf.Dynamic.WebHandlers[len(mox.Conf.Dynamic.WebHandlers)-1].WebForward.TargetURL = serverURL
 
126	test("GET", "http://mox.example/strip/x", badForwarded, http.StatusOK, "/a/x", map[string]string{
 
128		"X-Forwarded-For":   "192.0.2.1", // IP is hardcoded in Go's src/net/http/httptest/httptest.go
 
129		"X-Forwarded-Proto": "http",
 
130		"X-Forwarded-Host":  "mox.example",
 
131		"X-Forwarded-Ext":   "",
 
133	test("GET", "http://mox.example/nostrip/x", map[string]string{"X-OK": "ok"}, http.StatusOK, "/a/nostrip/x", map[string]string{"X-Test": "mox"})
 
135	test("GET", "http://mox.example/bogus", nil, http.StatusNotFound, "", nil)         // path not registered.
 
136	test("GET", "http://bogus.mox.example/static/", nil, http.StatusNotFound, "", nil) // domain not registered.
 
138	npaths := len(staticgzcache.paths)
 
140		t.Fatalf("%d file(s) in staticgzcache, expected 1", npaths)
 
142	loadStaticGzipCache(mox.DataDirPath("tmp/httpstaticcompresscache"), 1024*1024)
 
143	npaths = len(staticgzcache.paths)
 
145		t.Fatalf("%d file(s) in staticgzcache after loading from disk, expected 1", npaths)
 
147	loadStaticGzipCache(mox.DataDirPath("tmp/httpstaticcompresscache"), 0)
 
148	npaths = len(staticgzcache.paths)
 
150		t.Fatalf("%d file(s) in staticgzcache after setting max size to 0, expected 0", npaths)
 
152	loadStaticGzipCache(mox.DataDirPath("tmp/httpstaticcompresscache"), 0)
 
153	npaths = len(staticgzcache.paths)
 
155		t.Fatalf("%d file(s) in staticgzcache after setting max size to 0 and reloading from disk, expected 0", npaths)
 
159func TestWebsocket(t *testing.T) {
 
160	os.RemoveAll("../testdata/websocket/data")
 
161	mox.ConfigStaticPath = filepath.FromSlash("../testdata/websocket/mox.conf")
 
162	mox.ConfigDynamicPath = filepath.Join(filepath.Dir(mox.ConfigStaticPath), "domains.conf")
 
163	mox.MustLoadConfig(true, false)
 
165	srv := &serve{Webserver: true}
 
167	var handler http.Handler // Active handler during test.
 
168	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
169		handler.ServeHTTP(w, r)
 
172	defer backend.Close()
 
173	backendURL, err := url.Parse(backend.URL)
 
175		t.Fatalf("parsing backend url: %v", err)
 
177	backendURL.Path = "/"
 
179	// warning: it is not normally allowed to access the dynamic config without lock. don't propagate accesses like this!
 
180	mox.Conf.Dynamic.WebHandlers[len(mox.Conf.Dynamic.WebHandlers)-1].WebForward.TargetURL = backendURL
 
182	server := httptest.NewServer(srv)
 
185	serverURL, err := url.Parse(server.URL)
 
186	tcheck(t, err, "parsing server url")
 
187	_, port, err := net.SplitHostPort(serverURL.Host)
 
188	tcheck(t, err, "parsing host port in server url")
 
189	wsurl := fmt.Sprintf("ws://%s/ws/", net.JoinHostPort("localhost", port))
 
191	handler = websocket.Handler(func(c *websocket.Conn) {
 
195	// Test a correct websocket connection.
 
196	wsconn, err := websocket.Dial(wsurl, "ignored", "http://ignored.example")
 
197	tcheck(t, err, "websocket dial")
 
198	_, err = fmt.Fprint(wsconn, "test")
 
199	tcheck(t, err, "write to websocket")
 
200	buf := make([]byte, 128)
 
201	n, err := wsconn.Read(buf)
 
202	tcheck(t, err, "read from websocket")
 
203	if string(buf[:n]) != "test" {
 
204		t.Fatalf(`got websocket data %q, expected "test"`, buf[:n])
 
207	tcheck(t, err, "closing websocket connection")
 
209	// Test with server.ServeHTTP directly.
 
210	test := func(method string, reqhdrs map[string]string, expCode int, expHeaders map[string]string) {
 
213		req := httptest.NewRequest(method, wsurl, nil)
 
214		for k, v := range reqhdrs {
 
217		rw := httptest.NewRecorder()
 
218		rw.Body = &bytes.Buffer{}
 
219		srv.ServeHTTP(rw, req)
 
221		if resp.StatusCode != expCode {
 
222			t.Fatalf("got statuscode %d, expected %d", resp.StatusCode, expCode)
 
224		for k, v := range expHeaders {
 
225			if xv := resp.Header.Get(k); xv != v {
 
226				t.Fatalf("got %q for header %q, expected %q", xv, k, v)
 
231	wsreqhdrs := map[string]string{
 
232		"Upgrade":               "keep-alive, websocket",
 
233		"Connection":            "X, Upgrade",
 
234		"Sec-Websocket-Version": "13",
 
235		"Sec-Websocket-Key":     "AAAAAAAAAAAAAAAAAAAAAA==",
 
238	test("POST", wsreqhdrs, http.StatusBadRequest, nil)
 
240	clone := func(m map[string]string) map[string]string {
 
241		r := map[string]string{}
 
242		for k, v := range m {
 
248	hdrs := clone(wsreqhdrs)
 
249	hdrs["Sec-Websocket-Version"] = "14"
 
250	test("GET", hdrs, http.StatusBadRequest, map[string]string{"Sec-Websocket-Version": "13"})
 
252	httpurl := fmt.Sprintf("http://%s/ws/", net.JoinHostPort("localhost", port))
 
254	// Must now do actual HTTP requests and read the HTTP response. Cannot call
 
255	// ServeHTTP because ResponseRecorder is not a http.Hijacker.
 
256	test = func(method string, reqhdrs map[string]string, expCode int, expHeaders map[string]string) {
 
259		req, err := http.NewRequest(method, httpurl, nil)
 
260		tcheck(t, err, "http newrequest")
 
261		for k, v := range reqhdrs {
 
264		resp, err := http.DefaultClient.Do(req)
 
265		tcheck(t, err, "http transaction")
 
266		if resp.StatusCode != expCode {
 
267			t.Fatalf("got statuscode %d, expected %d", resp.StatusCode, expCode)
 
269		for k, v := range expHeaders {
 
270			if xv := resp.Header.Get(k); xv != v {
 
271				t.Fatalf("got %q for header %q, expected %q", xv, k, v)
 
276	hdrs = clone(wsreqhdrs)
 
277	hdrs["Sec-Websocket-Key"] = "malformed"
 
278	test("GET", hdrs, http.StatusBadRequest, nil)
 
280	hdrs = clone(wsreqhdrs)
 
281	hdrs["Sec-Websocket-Key"] = "c2hvcnQK" // "short"
 
282	test("GET", hdrs, http.StatusBadRequest, nil)
 
284	// Not responding with a 101, but with regular 200 OK response.
 
285	handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
286		http.Error(w, "bad", http.StatusOK)
 
288	test("GET", wsreqhdrs, http.StatusBadRequest, nil)
 
290	// Respond with 101, but other websocket response headers missing.
 
291	handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
292		w.WriteHeader(http.StatusSwitchingProtocols)
 
294	test("GET", wsreqhdrs, http.StatusBadRequest, nil)
 
296	// With Upgrade: websocket, without Connection: Upgrade
 
297	handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
298		w.Header().Set("Upgrade", "websocket")
 
299		w.WriteHeader(http.StatusSwitchingProtocols)
 
301	test("GET", wsreqhdrs, http.StatusBadRequest, nil)
 
303	// With malformed Sec-WebSocket-Accept response header.
 
304	handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
306		h.Set("Upgrade", "websocket")
 
307		h.Set("Connection", "Upgrade")
 
308		h.Set("Sec-WebSocket-Accept", "malformed")
 
309		w.WriteHeader(http.StatusSwitchingProtocols)
 
311	test("GET", wsreqhdrs, http.StatusBadRequest, nil)
 
313	// With malformed Sec-WebSocket-Accept response header.
 
314	handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
316		h.Set("Upgrade", "websocket")
 
317		h.Set("Connection", "Upgrade")
 
318		h.Set("Sec-WebSocket-Accept", "YmFk") // "bad"
 
319		w.WriteHeader(http.StatusSwitchingProtocols)
 
321	test("GET", wsreqhdrs, http.StatusBadRequest, nil)
 
324	wsresphdrs := map[string]string{
 
325		"Connection":           "Upgrade",
 
326		"Upgrade":              "websocket",
 
327		"Sec-Websocket-Accept": "ICX+Yqv66kxgM0FcWaLWlFLwTAI=",
 
330	handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 
332		h.Set("Upgrade", "websocket")
 
333		h.Set("Connection", "Upgrade")
 
334		h.Set("Sec-WebSocket-Accept", "ICX+Yqv66kxgM0FcWaLWlFLwTAI=")
 
335		w.WriteHeader(http.StatusSwitchingProtocols)
 
337	test("GET", wsreqhdrs, http.StatusSwitchingProtocols, wsresphdrs)