1package http
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "crypto/sha1"
8 "crypto/tls"
9 "encoding/base64"
10 "errors"
11 "fmt"
12 htmltemplate "html/template"
13 "io"
14 "io/fs"
15 golog "log"
16 "log/slog"
17 "net"
18 "net/http"
19 "net/http/httputil"
20 "net/textproto"
21 "net/url"
22 "os"
23 "path/filepath"
24 "sort"
25 "strings"
26 "syscall"
27 "time"
28
29 "github.com/mjl-/mox/config"
30 "github.com/mjl-/mox/dns"
31 "github.com/mjl-/mox/mlog"
32 "github.com/mjl-/mox/mox-"
33 "github.com/mjl-/mox/moxio"
34)
35
36func recvid(r *http.Request) string {
37 cid := mox.CidFromCtx(r.Context())
38 if cid <= 0 {
39 return ""
40 }
41 return " (id " + mox.ReceivedID(cid) + ")"
42}
43
44// WebHandle serves an HTTP request by going through the list of WebHandlers,
45// check if there is a domain+path match, and running the handler if so.
46// WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc.
47// If no handler matched, false is returned.
48// WebHandle sets w.Name to that of the matching handler.
49func WebHandle(w *loggingWriter, r *http.Request, host dns.Domain) (handled bool) {
50 redirects, handlers := mox.Conf.WebServer()
51
52 for from, to := range redirects {
53 if host != from {
54 continue
55 }
56 u := r.URL
57 u.Scheme = "https"
58 u.Host = to.Name()
59 w.Handler = "(domainredirect)"
60 http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
61 return true
62 }
63
64 for _, h := range handlers {
65 if host != h.DNSDomain {
66 continue
67 }
68 loc := h.Path.FindStringIndex(r.URL.Path)
69 if loc == nil {
70 continue
71 }
72 s := loc[0]
73 e := loc[1]
74 path := r.URL.Path[s:e]
75
76 if r.TLS == nil && !h.DontRedirectPlainHTTP {
77 u := *r.URL
78 u.Scheme = "https"
79 u.Host = h.DNSDomain.Name()
80 w.Handler = h.Name
81 w.Compress = h.Compress
82 http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
83 return true
84 }
85
86 // We don't want the loggingWriter to override the static handler's decisions to compress.
87 w.Compress = h.Compress
88 if h.WebStatic != nil && HandleStatic(h.WebStatic, h.Compress, w, r) {
89 w.Handler = h.Name
90 return true
91 }
92 if h.WebRedirect != nil && HandleRedirect(h.WebRedirect, w, r) {
93 w.Handler = h.Name
94 return true
95 }
96 if h.WebForward != nil && HandleForward(h.WebForward, w, r, path) {
97 w.Handler = h.Name
98 return true
99 }
100 }
101 w.Compress = false
102 return false
103}
104
// lsTemplate renders the directory listing for static directories that have
// ListFiles enabled. It is executed with a map whose "Files" key holds entries with
// fields Name, Size, SizeReadable, SizePad and Modified.
//
// Links are built with the "href" function: each name is path-escaped, and the
// action is placed after a "./" prefix. Without this, names containing "?", "#" or
// "%" would be interpreted as query, fragment or escape, and a name like "foo:bar"
// would be taken as a URL scheme (which html/template replaces with "#ZgotmplZ").
// Directory names end with "/", which is kept as a path separator.
//
// Sizes without decimal point (SizePad) are padded with an invisible ".00", so
// they align with sizes formatted with two decimals.
var lsTemplate = htmltemplate.Must(htmltemplate.New("ls").Funcs(htmltemplate.FuncMap{
	"href": func(name string) string {
		if dir, ok := strings.CutSuffix(name, "/"); ok {
			return url.PathEscape(dir) + "/"
		}
		return url.PathEscape(name)
	},
}).Parse(`<!doctype html>
<html>
	<head>
		<meta charset="utf-8" />
		<meta name="viewport" content="width=device-width, initial-scale=1" />
		<title>ls</title>
		<style>
body, html { padding: 1em; font-size: 16px; }
* { font-size: inherit; font-family: ubuntu, lato, sans-serif; margin: 0; padding: 0; box-sizing: border-box; }
h1 { margin-bottom: 1ex; font-size: 1.2rem; }
table td, table th { padding: .2em .5em; }
table > tbody > tr:nth-child(odd) { background-color: #f8f8f8; }
[title] { text-decoration: underline; text-decoration-style: dotted; }
		</style>
	</head>
	<body>
		<h1>ls</h1>
		<table>
			<thead>
				<tr>
					<th>Size in MB</th>
					<th>Modified (UTC)</th>
					<th>Name</th>
				</tr>
			</thead>
			<tbody>
				{{ if not .Files }}
				<tr><td colspan="3">No files.</td></tr>
				{{ end }}
				{{ range .Files }}
				<tr>
					<td title="{{ .Size }} bytes" style="text-align: right">{{ .SizeReadable }}{{ if .SizePad }}<span style="visibility:hidden">.00</span>{{ end }}</td>
					<td>{{ .Modified }}</td>
					<td><a style="display: block" href="./{{ href .Name }}">{{ .Name }}</a></td>
				</tr>
				{{ end }}
			</tbody>
		</table>
	</body>
</html>
`))
146
147// HandleStatic serves static files. If a directory is requested and the URL
148// path doesn't end with a slash, a response with a redirect to the URL path with trailing
149// slash is written. If a directory is requested and an index.html exists, that
150// file is returned. Otherwise, for directories with ListFiles configured, a
151// directory listing is returned.
152func HandleStatic(h *config.WebStatic, compress bool, w http.ResponseWriter, r *http.Request) (handled bool) {
153 log := func() mlog.Log {
154 return pkglog.WithContext(r.Context())
155 }
156 if r.Method != "GET" && r.Method != "HEAD" {
157 if h.ContinueNotFound {
158 // Give another handler that is presumbly configured, for the same path, a chance.
159 // E.g. an app that may generate this file for future requests to pick up.
160 return false
161 }
162 http.Error(w, "405 - method not allowed", http.StatusMethodNotAllowed)
163 return true
164 }
165
166 var fspath string
167 if h.StripPrefix != "" {
168 if !strings.HasPrefix(r.URL.Path, h.StripPrefix) {
169 if h.ContinueNotFound {
170 // We haven't handled this request, try a next WebHandler in the list.
171 return false
172 }
173 http.NotFound(w, r)
174 return true
175 }
176 fspath = filepath.Join(h.Root, strings.TrimPrefix(r.URL.Path, h.StripPrefix))
177 } else {
178 fspath = filepath.Join(h.Root, r.URL.Path)
179 }
180 // fspath will not have a trailing slash anymore, we'll correct for it
181 // later when the path turns out to be file instead of a directory.
182
183 serveFile := func(name string, fi fs.FileInfo, content *os.File) {
184 // ServeContent only sets a content-type if not already present in the response headers.
185 hdr := w.Header()
186 for k, v := range h.ResponseHeaders {
187 hdr.Add(k, v)
188 }
189 // We transparently compress here, but still use ServeContent, because it handles
190 // conditional requests, range requests. It's a bit of a hack, but on first write
191 // to staticgzcacheReplacer where we are compressing, we write the full compressed
192 // file instead, and return an error to ServeContent so it stops. We still have all
193 // the useful behaviour (status code and headers) from ServeContent.
194 xw := w
195 if compress && acceptsGzip(r) && compressibleContent(content) {
196 xw = &staticgzcacheReplacer{w, r, content.Name(), content, fi.ModTime(), fi.Size(), 0, false}
197 } else {
198 w.(*loggingWriter).Compress = false
199 }
200 http.ServeContent(xw, r, name, fi.ModTime(), content)
201 }
202
203 f, err := os.Open(fspath)
204 if err != nil {
205 if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) {
206 if h.ContinueNotFound {
207 // We haven't handled this request, try a next WebHandler in the list.
208 return false
209 }
210 http.NotFound(w, r)
211 return true
212 } else if os.IsPermission(err) {
213 // If we tried opening a directory, we may not have permission to read it, but
214 // still access files inside it (execute bit), such as index.html. So try to serve it.
215 index, err := os.Open(filepath.Join(fspath, "index.html"))
216 if err == nil {
217 defer index.Close()
218 var ifi os.FileInfo
219 ifi, err = index.Stat()
220 if err != nil {
221 log().Errorx("stat index.html in directory we cannot list", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
222 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
223 return true
224 }
225 w.Header().Set("Content-Type", "text/html; charset=utf-8")
226 serveFile("index.html", ifi, index)
227 return true
228 }
229 http.Error(w, "403 - permission denied", http.StatusForbidden)
230 return true
231 }
232 log().Errorx("open file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
233 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
234 return true
235 }
236 defer f.Close()
237
238 fi, err := f.Stat()
239 if err != nil {
240 log().Errorx("stat file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
241 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
242 return true
243 }
244 // Redirect if the local path is a directory.
245 if fi.IsDir() && !strings.HasSuffix(r.URL.Path, "/") {
246 http.Redirect(w, r, r.URL.Path+"/", http.StatusTemporaryRedirect)
247 return true
248 } else if !fi.IsDir() && strings.HasSuffix(r.URL.Path, "/") {
249 if h.ContinueNotFound {
250 return false
251 }
252 http.NotFound(w, r)
253 return true
254 }
255
256 if fi.IsDir() {
257 index, err := os.Open(filepath.Join(fspath, "index.html"))
258 if err != nil && os.IsPermission(err) {
259 http.Error(w, "403 - permission denied", http.StatusForbidden)
260 return true
261 } else if err != nil && os.IsNotExist(err) && !h.ListFiles {
262 if h.ContinueNotFound {
263 return false
264 }
265 http.Error(w, "403 - permission denied", http.StatusForbidden)
266 return true
267 } else if err == nil {
268 defer index.Close()
269 var ifi os.FileInfo
270 ifi, err = index.Stat()
271 if err == nil {
272 w.Header().Set("Content-Type", "text/html; charset=utf-8")
273 serveFile("index.html", ifi, index)
274 return true
275 }
276 }
277 if !os.IsNotExist(err) {
278 log().Errorx("stat for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
279 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
280 return true
281 }
282
283 type File struct {
284 Name string
285 Size int64
286 SizeReadable string
287 SizePad bool // Whether the size needs padding because it has no decimal point.
288 Modified string
289 }
290 files := []File{}
291 if r.URL.Path != "/" {
292 files = append(files, File{"..", 0, "", false, ""})
293 }
294 for {
295 l, err := f.Readdir(1000)
296 for _, e := range l {
297 mb := float64(e.Size()) / (1024 * 1024)
298 var size string
299 var sizepad bool
300 if !e.IsDir() {
301 if mb >= 10 {
302 size = fmt.Sprintf("%d", int64(mb))
303 sizepad = true
304 } else {
305 size = fmt.Sprintf("%.2f", mb)
306 }
307 }
308 const dateTime = "2006-01-02 15:04:05" // time.DateTime, but only since go1.20.
309 modified := e.ModTime().UTC().Format(dateTime)
310 f := File{e.Name(), e.Size(), size, sizepad, modified}
311 if e.IsDir() {
312 f.Name += "/"
313 }
314 files = append(files, f)
315 }
316 if err == io.EOF {
317 break
318 } else if err != nil {
319 log().Errorx("reading directory for file listing", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
320 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
321 return true
322 }
323 }
324 sort.Slice(files, func(i, j int) bool {
325 return files[i].Name < files[j].Name
326 })
327 hdr := w.Header()
328 hdr.Set("Content-Type", "text/html; charset=utf-8")
329 for k, v := range h.ResponseHeaders {
330 if !strings.EqualFold(k, "content-type") {
331 hdr.Add(k, v)
332 }
333 }
334 err = lsTemplate.Execute(w, map[string]any{"Files": files})
335 if err != nil && !moxio.IsClosed(err) {
336 log().Errorx("executing directory listing template", err)
337 }
338 return true
339 }
340
341 serveFile(fspath, fi, f)
342 return true
343}
344
345// HandleRedirect writes a response with an HTTP redirect.
346func HandleRedirect(h *config.WebRedirect, w http.ResponseWriter, r *http.Request) (handled bool) {
347 var dstpath string
348 if h.OrigPath == nil {
349 // No path rewrite necessary.
350 dstpath = r.URL.Path
351 } else if !h.OrigPath.MatchString(r.URL.Path) {
352 http.NotFound(w, r)
353 return true
354 } else {
355 dstpath = h.OrigPath.ReplaceAllString(r.URL.Path, h.ReplacePath)
356 }
357
358 u := *r.URL
359 u.Opaque = ""
360 u.RawPath = ""
361 u.OmitHost = false
362 if h.URL != nil {
363 u.Scheme = h.URL.Scheme
364 u.Host = h.URL.Host
365 u.ForceQuery = h.URL.ForceQuery
366 u.RawQuery = h.URL.RawQuery
367 u.Fragment = h.URL.Fragment
368 if r.URL.RawQuery != "" {
369 if u.RawQuery != "" {
370 u.RawQuery += "&"
371 }
372 u.RawQuery += r.URL.RawQuery
373 }
374 }
375 u.Path = dstpath
376 code := http.StatusPermanentRedirect
377 if h.StatusCode != 0 {
378 code = h.StatusCode
379 }
380
381 // If we would be redirecting to the same scheme,host,path, we would get here again
382 // causing a redirect loop. Instead, this causes this redirect to not match,
383 // allowing to try the next WebHandler. This can be used to redirect all plain http
384 // requests to https.
385 reqscheme := "http"
386 if r.TLS != nil {
387 reqscheme = "https"
388 }
389 if reqscheme == u.Scheme && r.Host == u.Host && r.URL.Path == u.Path {
390 return false
391 }
392
393 http.Redirect(w, r, u.String(), code)
394 return true
395}
396
397// HandleForward handles a request by forwarding it to another webserver and
398// passing the response on. I.e. a reverse proxy. It handles websocket
399// connections by monitoring the websocket handshake and then just passing along the
400// websocket frames.
401func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
402 log := func() mlog.Log {
403 return pkglog.WithContext(r.Context())
404 }
405
406 xr := *r
407 r = &xr
408 if h.StripPath {
409 u := *r.URL
410 u.Path = r.URL.Path[len(path):]
411 if !strings.HasPrefix(u.Path, "/") {
412 u.Path = "/" + u.Path
413 }
414 u.RawPath = ""
415 r.URL = &u
416 }
417
418 // Remove any forwarded headers passed in by client.
419 hdr := http.Header{}
420 for k, vl := range r.Header {
421 if k == "Forwarded" || k == "X-Forwarded" || strings.HasPrefix(k, "X-Forwarded-") {
422 continue
423 }
424 hdr[k] = vl
425 }
426 r.Header = hdr
427
428 // Add our own X-Forwarded headers. ReverseProxy will add X-Forwarded-For.
429 r.Header["X-Forwarded-Host"] = []string{r.Host}
430 proto := "http"
431 if r.TLS != nil {
432 proto = "https"
433 }
434 r.Header["X-Forwarded-Proto"] = []string{proto}
435 // note: We are not using "ws" or "wss" for websocket. The request we are
436 // forwarding is http(s), and we don't yet know if the backend even supports
437 // websockets.
438
439 // todo: add Forwarded header? is anyone using it?
440
441 // If we see an Upgrade: websocket, we're going to assume the client needs
442 // websocket and only attempt to talk websocket with the backend. If the backend
443 // doesn't do websocket, we'll send back a "bad request" response. For other values
444 // of Upgrade, we don't do anything special.
445 // https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml
446 // Upgrade: ../rfc/9110:2798
447 // Upgrade headers are not for http/1.0, ../rfc/9110:2880
448 // Websocket client "handshake" is described at ../rfc/6455:1134
449 upgrade := r.Header.Get("Upgrade")
450 if upgrade != "" && !(r.ProtoMajor == 1 && r.ProtoMinor == 0) {
451 // Websockets have case-insensitive string "websocket".
452 for _, s := range strings.Split(upgrade, ",") {
453 if strings.EqualFold(textproto.TrimString(s), "websocket") {
454 forwardWebsocket(h, w, r, path)
455 return true
456 }
457 }
458 }
459
460 // ReverseProxy will append any remaining path to the configured target URL.
461 proxy := httputil.NewSingleHostReverseProxy(h.TargetURL)
462 proxy.FlushInterval = time.Duration(-1) // Flush after each write.
463 proxy.ErrorLog = golog.New(mlog.LogWriter(mlog.New("net/http/httputil", nil).WithContext(r.Context()), mlog.LevelDebug, "reverseproxy error"), "", 0)
464 proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
465 if errors.Is(err, context.Canceled) {
466 log().Debugx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
467 return
468 }
469 log().Errorx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
470 if os.IsTimeout(err) {
471 http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
472 } else {
473 http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
474 }
475 }
476 whdr := w.Header()
477 for k, v := range h.ResponseHeaders {
478 whdr.Add(k, v)
479 }
480 proxy.ServeHTTP(w, r)
481 return true
482}
483
var (
	// errResponseNotWebsocket is wrapped by errors about a backend response that does
	// not accept the websocket upgrade.
	errResponseNotWebsocket = errors.New("not a valid websocket response to request")

	// errNotImplemented is wrapped by errors for functionality that is missing.
	errNotImplemented = errors.New("functionality not yet implemented")
)
486
487// Request has an Upgrade: websocket header. Check more websocketiness about the
488// request. If it looks good, we forward it to the backend. If the backend responds
489// with a valid websocket response, indicating it is indeed a websocket server, we
490// pass the response along and start copying data between the client and the
491// backend. We don't look at the frames and payloads. The backend already needs to
492// know enough websocket to handle the frames. It wouldn't necessarily hurt to
493// monitor the frames too, and check if they are valid, but it's quite a bit of
494// work for little benefit. Besides, the whole point of websockets is to exchange
495// bytes without HTTP being in the way, so let's do that.
496func forwardWebsocket(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
497 log := func() mlog.Log {
498 return pkglog.WithContext(r.Context())
499 }
500
501 lw := w.(*loggingWriter)
502 lw.WebsocketRequest = true // For correct protocol in metrics.
503
504 // We check the requested websocket version first. A future websocket version may
505 // have different request requirements.
506 // ../rfc/6455:1160
507 wsversion := r.Header.Get("Sec-WebSocket-Version")
508 if wsversion != "13" {
509 // Indicate we only support version 13. Should get a client from the future to fall back to version 13.
510 // ../rfc/6455:1435
511 w.Header().Set("Sec-WebSocket-Version", "13")
512 http.Error(w, "400 - bad request - websockets only supported with version 13"+recvid(r), http.StatusBadRequest)
513 lw.error(fmt.Errorf("Sec-WebSocket-Version %q not supported", wsversion))
514 return true
515 }
516
517 // ../rfc/6455:1143
518 if r.Method != "GET" {
519 http.Error(w, "400 - bad request - websockets only allowed with method GET"+recvid(r), http.StatusBadRequest)
520 lw.error(fmt.Errorf("websocket request only allowed with method GET"))
521 return true
522 }
523
524 // ../rfc/6455:1153
525 var connectionUpgrade bool
526 for _, s := range strings.Split(r.Header.Get("Connection"), ",") {
527 if strings.EqualFold(textproto.TrimString(s), "upgrade") {
528 connectionUpgrade = true
529 break
530 }
531 }
532 if !connectionUpgrade {
533 http.Error(w, "400 - bad request - connection header must be \"upgrade\""+recvid(r), http.StatusBadRequest)
534 lw.error(fmt.Errorf(`connection header is %q, must be "upgrade"`, r.Header.Get("Connection")))
535 return true
536 }
537
538 // ../rfc/6455:1156
539 wskey := r.Header.Get("Sec-WebSocket-Key")
540 key, err := base64.StdEncoding.DecodeString(wskey)
541 if err != nil || len(key) != 16 {
542 http.Error(w, "400 - bad request - websockets requires Sec-WebSocket-Key with 16 bytes base64-encoded value"+recvid(r), http.StatusBadRequest)
543 lw.error(fmt.Errorf("bad Sec-WebSocket-Key %q, must be 16 byte base64-encoded value", wskey))
544 return true
545 }
546
547 // ../rfc/6455:1162
548 // We don't look at the origin header. The backend needs to handle it, if it thinks
549 // that helps...
550 // We also don't look at Sec-WebSocket-Protocol and Sec-WebSocket-Extensions. The
551 // backend can set them, but it doesn't influence our forwarding of the data.
552
553 // If this is not a hijacker, there is not point in connecting to the backend.
554 hj, ok := lw.W.(http.Hijacker)
555 var cbr *bufio.ReadWriter
556 if !ok {
557 log().Info("cannot turn http connection into tcp connection (http.Hijacker)")
558 http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
559 lw.error(fmt.Errorf("connection not a http.Hijacker (%T)", lw.W))
560 return
561 }
562
563 freq := *r
564 freq.Proto = "HTTP/1.1"
565 freq.ProtoMajor = 1
566 freq.ProtoMinor = 1
567 fresp, beconn, err := websocketTransact(r.Context(), h.TargetURL, &freq)
568 if err != nil {
569 if errors.Is(err, errResponseNotWebsocket) {
570 http.Error(w, "400 - bad request - websocket not supported"+recvid(r), http.StatusBadRequest)
571 } else if errors.Is(err, errNotImplemented) {
572 http.Error(w, "501 - not implemented - "+err.Error()+recvid(r), http.StatusNotImplemented)
573 } else if os.IsTimeout(err) {
574 http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
575 } else {
576 http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
577 }
578 lw.error(err)
579 return
580 }
581 defer func() {
582 if beconn != nil {
583 beconn.Close()
584 }
585 }()
586
587 // Hijack the client connection so we can write the response ourselves, and start
588 // copying the websocket frames.
589 var cconn net.Conn
590 cconn, cbr, err = hj.Hijack()
591 if err != nil {
592 log().Debugx("cannot turn http transaction into websocket connection", err)
593 http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
594 lw.error(err)
595 return
596 }
597 defer func() {
598 if cconn != nil {
599 cconn.Close()
600 }
601 }()
602
603 // Below this point, we can no longer write to the ResponseWriter.
604
605 // Mark as websocket response, for logging.
606 lw.WebsocketResponse = true
607 lw.setStatusCode(fresp.StatusCode)
608
609 for k, v := range h.ResponseHeaders {
610 fresp.Header.Add(k, v)
611 }
612
613 // Write the response to the client, completing its websocket handshake.
614 if err := fresp.Write(cconn); err != nil {
615 lw.error(fmt.Errorf("writing websocket response to client: %w", err))
616 return
617 }
618
619 errc := make(chan error, 1)
620
621 // Copy from client to backend.
622 go func() {
623 buf, err := cbr.Peek(cbr.Reader.Buffered())
624 if err != nil {
625 errc <- err
626 return
627 }
628 if len(buf) > 0 {
629 n, err := beconn.Write(buf)
630 if err != nil {
631 errc <- err
632 return
633 }
634 lw.SizeFromClient += int64(n)
635 }
636 n, err := io.Copy(beconn, cconn)
637 lw.SizeFromClient += n
638 errc <- err
639 }()
640
641 // Copy from backend to client.
642 go func() {
643 n, err := io.Copy(cconn, beconn)
644 lw.SizeToClient = n
645 errc <- err
646 }()
647
648 // Stop and close connection on first error from either size, typically a closed
649 // connection whose closing was already announced with a websocket frame.
650 lw.error(<-errc)
651 // Close connections so other goroutine stops as well.
652 cconn.Close()
653 beconn.Close()
654 // Wait for goroutine so it has updated the logWriter.Size*Client fields before we
655 // continue with logging.
656 <-errc
657 cconn = nil
658 return true
659}
660
661func websocketTransact(ctx context.Context, targetURL *url.URL, r *http.Request) (rresp *http.Response, rconn net.Conn, rerr error) {
662 log := func() mlog.Log {
663 return pkglog.WithContext(r.Context())
664 }
665
666 // Dial the backend, possibly doing TLS. We assume the net/http DefaultTransport is
667 // unmodified.
668 transport := http.DefaultTransport.(*http.Transport)
669
670 // We haven't implemented using a proxy for websocket requests yet. If we need one,
671 // return an error instead of trying to connect directly, which would be a
672 // potential security issue.
673 treq := *r
674 treq.URL = targetURL
675 if purl, err := transport.Proxy(&treq); err != nil {
676 return nil, nil, fmt.Errorf("determining proxy for websocket backend connection: %w", err)
677 } else if purl != nil {
678 return nil, nil, fmt.Errorf("%w: proxy required for websocket connection to backend", errNotImplemented) // todo: implement?
679 }
680
681 host, port, err := net.SplitHostPort(targetURL.Host)
682 if err != nil {
683 host = targetURL.Host
684 if targetURL.Scheme == "https" {
685 port = "443"
686 } else {
687 port = "80"
688 }
689 }
690 addr := net.JoinHostPort(host, port)
691 conn, err := transport.DialContext(r.Context(), "tcp", addr)
692 if err != nil {
693 return nil, nil, fmt.Errorf("dial: %w", err)
694 }
695 if targetURL.Scheme == "https" {
696 tlsconn := tls.Client(conn, transport.TLSClientConfig)
697 ctx, cancel := context.WithTimeout(r.Context(), transport.TLSHandshakeTimeout)
698 defer cancel()
699 if err := tlsconn.HandshakeContext(ctx); err != nil {
700 return nil, nil, fmt.Errorf("tls handshake: %w", err)
701 }
702 conn = tlsconn
703 }
704 defer func() {
705 if rerr != nil {
706 conn.Close()
707 }
708 }()
709
710 // todo: make timeout configurable?
711 if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
712 log().Check(err, "set deadline for websocket request to backend")
713 }
714
715 // Set clean connection headers.
716 removeHopByHopHeaders(r.Header)
717 r.Header.Set("Connection", "Upgrade")
718 r.Header.Set("Upgrade", "websocket")
719
720 // Write the websocket request to the backend.
721 if err := r.Write(conn); err != nil {
722 return nil, nil, fmt.Errorf("writing request to backend: %w", err)
723 }
724
725 // Read response from backend.
726 br := bufio.NewReader(conn)
727 resp, err := http.ReadResponse(br, r)
728 if err != nil {
729 return nil, nil, fmt.Errorf("reading response from backend: %w", err)
730 }
731 defer func() {
732 if rerr != nil {
733 resp.Body.Close()
734 }
735 }()
736 if err := conn.SetDeadline(time.Time{}); err != nil {
737 log().Check(err, "clearing deadline on websocket connection to backend")
738 }
739
740 // Check that the response from the backend server indicates it is websocket. If
741 // not, don't pass the backend response, but an error that websocket is not
742 // appropriate.
743 if err := checkWebsocketResponse(resp, r); err != nil {
744 return resp, nil, err
745 }
746
747 // note: net/http.Response.Body documents that it implements io.Writer for a
748 // status: 101 response. But that's not the case when the response has been read
749 // with http.ReadResponse. We'll write to the connection directly.
750
751 buf, err := br.Peek(br.Buffered())
752 if err != nil {
753 return resp, nil, fmt.Errorf("peek at buffered data written by backend: %w", err)
754 }
755 return resp, websocketConn{io.MultiReader(bytes.NewReader(buf), conn), conn}, nil
756}
757
// websocketConn is a net.Conn whose reads are served by r instead of by the
// embedded connection. websocketTransact sets r to a reader that first returns the
// bytes its bufio.Reader had already buffered while reading the backend's HTTP
// response, followed by the connection itself. All other methods (Write, Close,
// deadlines, addresses) are those of the embedded net.Conn.
type websocketConn struct {
	r io.Reader
	net.Conn
}

// Read reads from the replacement reader, not from the embedded connection.
func (wc websocketConn) Read(p []byte) (n int, err error) {
	n, err = wc.r.Read(p)
	return
}
768
769// Check that an HTTP response (from a backend) is a valid websocket response, i.e.
770// that it accepts the WebSocket "upgrade".
771// ../rfc/6455:1299
772func checkWebsocketResponse(resp *http.Response, req *http.Request) error {
773 if resp.StatusCode != 101 {
774 return fmt.Errorf("%w: response http status not 101 but %s", errResponseNotWebsocket, resp.Status)
775 }
776 if upgrade := resp.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") {
777 return fmt.Errorf(`%w: response http status is 101, but Upgrade header is %q, should be "websocket"`, errResponseNotWebsocket, upgrade)
778 }
779 if connection := resp.Header.Get("Connection"); !strings.EqualFold(connection, "upgrade") {
780 return fmt.Errorf(`%w: response http status is 101, Upgrade is websocket, but Connection header is %q, should be "Upgrade"`, errResponseNotWebsocket, connection)
781 }
782 accept, err := base64.StdEncoding.DecodeString(resp.Header.Get("Sec-WebSocket-Accept"))
783 if err != nil {
784 return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but Sec-WebSocket-Accept header is not valid base64: %v`, errResponseNotWebsocket, err)
785 }
786 exp := sha1.Sum([]byte(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
787 if !bytes.Equal(accept, exp[:]) {
788 return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but backend Sec-WebSocket-Accept value does not match`, errResponseNotWebsocket)
789 }
790 // We don't have requirements for the other Sec-WebSocket headers. ../rfc/6455:1340
791 return nil
792}
793
// hopHeaders are the hop-by-hop header fields defined by the obsoleted RFC 2616
// (section 13.5.1). They only concern a single connection and are removed before a
// request is sent on to a backend. Since RFC 7230, hop-by-hop headers are required
// to be listed in the Connection header; this fixed list is still applied for
// backward compatibility.
// Copied from Go 1.20.4 src/net/http/httputil/reverseproxy.go.
// ../rfc/2616:5128
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // Non-standard, but still sent by libcurl and rejected by e.g. google.
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",      // Canonicalized version of "TE".
	"Trailer", // Not "Trailers", see https://www.rfc-editor.org/errata_search.php?eid=4522
	"Transfer-Encoding",
	"Upgrade",
}
812
813// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
814// removeHopByHopHeaders removes hop-by-hop headers.
815func removeHopByHopHeaders(h http.Header) {
816 // RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
817 // ../rfc/7230:2817
818 for _, f := range h["Connection"] {
819 for _, sf := range strings.Split(f, ",") {
820 if sf = textproto.TrimString(sf); sf != "" {
821 h.Del(sf)
822 }
823 }
824 }
825 // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
826 // This behavior is superseded by the RFC 7230 Connection header, but
827 // preserve it for backwards compatibility.
828 // ../rfc/2616:5128
829 for _, f := range hopHeaders {
830 h.Del(f)
831 }
832}
833