1package http
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 "crypto/sha1"
8 "crypto/tls"
9 "encoding/base64"
10 "errors"
11 "fmt"
12 htmltemplate "html/template"
13 "io"
14 "io/fs"
15 golog "log"
16 "log/slog"
17 "net"
18 "net/http"
19 "net/textproto"
20 "net/url"
21 "os"
22 "path/filepath"
23 "sort"
24 "strings"
25 "syscall"
26 "time"
27
28 "github.com/mjl-/mox/config"
29 "github.com/mjl-/mox/dns"
30 "github.com/mjl-/mox/mlog"
31 "github.com/mjl-/mox/mox-"
32)
33
34func recvid(r *http.Request) string {
35 cid := mox.CidFromCtx(r.Context())
36 if cid <= 0 {
37 return ""
38 }
39 return " (id " + mox.ReceivedID(cid) + ")"
40}
41
42// WebHandle serves an HTTP request by going through the list of WebHandlers,
43// check if there is a domain+path match, and running the handler if so.
44// WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc.
45// If no handler matched, false is returned.
46// WebHandle sets w.Name to that of the matching handler.
47func WebHandle(w *loggingWriter, r *http.Request, host dns.IPDomain) (handled bool) {
48 conf := mox.Conf.DynamicConfig()
49 redirects := conf.WebDNSDomainRedirects
50 handlers := conf.WebHandlers
51
52 for from, to := range redirects {
53 if host.Domain != from {
54 continue
55 }
56 u := r.URL
57 u.Scheme = "https"
58 u.Host = to.Name()
59 w.Handler = "(domainredirect)"
60 http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
61 return true
62 }
63
64 for _, h := range handlers {
65 if host.Domain != h.DNSDomain {
66 continue
67 }
68 loc := h.Path.FindStringIndex(r.URL.Path)
69 if loc == nil {
70 continue
71 }
72 s := loc[0]
73 e := loc[1]
74 path := r.URL.Path[s:e]
75
76 if r.TLS == nil && !h.DontRedirectPlainHTTP {
77 u := *r.URL
78 u.Scheme = "https"
79 u.Host = h.DNSDomain.Name()
80 w.Handler = h.Name
81 w.Compress = h.Compress
82 http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
83 return true
84 }
85
86 // We don't want the loggingWriter to override the static handler's decisions to compress.
87 w.Compress = h.Compress
88 if h.WebStatic != nil && HandleStatic(h.WebStatic, h.Compress, w, r) {
89 w.Handler = h.Name
90 return true
91 }
92 if h.WebRedirect != nil && HandleRedirect(h.WebRedirect, w, r) {
93 w.Handler = h.Name
94 return true
95 }
96 if h.WebForward != nil && HandleForward(h.WebForward, w, r, path) {
97 w.Handler = h.Name
98 return true
99 }
100 if h.WebInternal != nil && HandleInternal(h.WebInternal, w, r) {
101 w.Handler = h.Name
102 return true
103 }
104 }
105 w.Compress = false
106 return false
107}
108
// lsTemplate renders a directory listing. It is executed with a map with key
// "Files", a list of entries with fields Name, Size, SizeReadable, SizePad and
// Modified.
//
// Link targets are produced by relurl instead of inserting the file name
// directly. A raw file name is not a valid relative URL reference: "#" and "?"
// would start a fragment or query, "%" would be read as an escape, and a colon in
// the first segment would make html/template treat the name as an unsafe URL
// scheme and replace it with "#ZgotmplZ".
var lsTemplate = htmltemplate.Must(htmltemplate.New("ls").Funcs(htmltemplate.FuncMap{
	// relurl returns name as a relative URL path reference, percent-encoding where
	// needed and prefixing "./" when the first segment contains a colon.
	"relurl": func(name string) string {
		return (&url.URL{Path: name}).String()
	},
}).Parse(`<!doctype html>
<html>
	<head>
		<meta charset="utf-8" />
		<meta name="viewport" content="width=device-width, initial-scale=1" />
		<title>ls</title>
		<style>
body, html { padding: 1em; font-size: 16px; }
* { font-size: inherit; font-family: ubuntu, lato, sans-serif; margin: 0; padding: 0; box-sizing: border-box; }
h1 { margin-bottom: 1ex; font-size: 1.2rem; }
table td, table th { padding: .2em .5em; }
table > tbody > tr:nth-child(odd) { background-color: #f8f8f8; }
[title] { text-decoration: underline; text-decoration-style: dotted; }
		</style>
	</head>
	<body>
		<h1>ls</h1>
		<table>
			<thead>
				<tr>
					<th>Size in MB</th>
					<th>Modified (UTC)</th>
					<th>Name</th>
				</tr>
			</thead>
			<tbody>
			{{ if not .Files }}
				<tr><td colspan="3">No files.</td></tr>
			{{ end }}
			{{ range .Files }}
				<tr>
					<td title="{{ .Size }} bytes" style="text-align: right">{{ .SizeReadable }}{{ if .SizePad }}<span style="visibility:hidden">.  </span>{{ end }}</td>
					<td>{{ .Modified }}</td>
					<td><a style="display: block" href="{{ relurl .Name }}">{{ .Name }}</a></td>
				</tr>
			{{ end }}
			</tbody>
		</table>
	</body>
</html>
`))
150
151// HandleStatic serves static files. If a directory is requested and the URL
152// path doesn't end with a slash, a response with a redirect to the URL path with trailing
153// slash is written. If a directory is requested and an index.html exists, that
154// file is returned. Otherwise, for directories with ListFiles configured, a
155// directory listing is returned.
156func HandleStatic(h *config.WebStatic, compress bool, w http.ResponseWriter, r *http.Request) (handled bool) {
157 log := func() mlog.Log {
158 return pkglog.WithContext(r.Context())
159 }
160 if r.Method != "GET" && r.Method != "HEAD" {
161 if h.ContinueNotFound {
162 // Give another handler that is presumbly configured, for the same path, a chance.
163 // E.g. an app that may generate this file for future requests to pick up.
164 return false
165 }
166 http.Error(w, "405 - method not allowed", http.StatusMethodNotAllowed)
167 return true
168 }
169
170 var fspath string
171 if h.StripPrefix != "" {
172 if !strings.HasPrefix(r.URL.Path, h.StripPrefix) {
173 if h.ContinueNotFound {
174 // We haven't handled this request, try a next WebHandler in the list.
175 return false
176 }
177 http.NotFound(w, r)
178 return true
179 }
180 fspath = filepath.Join(h.Root, strings.TrimPrefix(r.URL.Path, h.StripPrefix))
181 } else {
182 fspath = filepath.Join(h.Root, r.URL.Path)
183 }
184 // fspath will not have a trailing slash anymore, we'll correct for it
185 // later when the path turns out to be file instead of a directory.
186
187 serveFile := func(name string, fi fs.FileInfo, content *os.File) {
188 // ServeContent only sets a content-type if not already present in the response headers.
189 hdr := w.Header()
190 for k, v := range h.ResponseHeaders {
191 hdr.Add(k, v)
192 }
193 // We transparently compress here, but still use ServeContent, because it handles
194 // conditional requests, range requests. It's a bit of a hack, but on first write
195 // to staticgzcacheReplacer where we are compressing, we write the full compressed
196 // file instead, and return an error to ServeContent so it stops. We still have all
197 // the useful behaviour (status code and headers) from ServeContent.
198 xw := w
199 if compress && acceptsGzip(r) && compressibleContent(content) {
200 xw = &staticgzcacheReplacer{w, r, content.Name(), content, fi.ModTime(), fi.Size(), 0, false}
201 } else {
202 w.(*loggingWriter).Compress = false
203 }
204 http.ServeContent(xw, r, name, fi.ModTime(), content)
205 }
206
207 f, err := os.Open(fspath)
208 if err != nil {
209 if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.EINVAL) {
210 if h.ContinueNotFound {
211 // We haven't handled this request, try a next WebHandler in the list.
212 return false
213 }
214 http.NotFound(w, r)
215 return true
216 } else if errors.Is(err, syscall.ENAMETOOLONG) {
217 http.NotFound(w, r)
218 return true
219 } else if os.IsPermission(err) {
220 // If we tried opening a directory, we may not have permission to read it, but
221 // still access files inside it (execute bit), such as index.html. So try to serve it.
222 index, err := os.Open(filepath.Join(fspath, "index.html"))
223 if err != nil {
224 http.Error(w, "403 - permission denied", http.StatusForbidden)
225 return true
226 }
227 defer func() {
228 err := index.Close()
229 log().Check(err, "closing index file for serving")
230 }()
231 var ifi os.FileInfo
232 ifi, err = index.Stat()
233 if err != nil {
234 log().Errorx("stat index.html in directory we cannot list", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
235 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
236 return true
237 }
238 w.Header().Set("Content-Type", "text/html; charset=utf-8")
239 serveFile("index.html", ifi, index)
240 return true
241 }
242 log().Errorx("open file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
243 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
244 return true
245 }
246 defer func() {
247 if err := f.Close(); err != nil {
248 log().Check(err, "closing file for static file serving")
249 }
250 }()
251
252 fi, err := f.Stat()
253 if err != nil {
254 log().Errorx("stat file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
255 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
256 return true
257 }
258 // Redirect if the local path is a directory.
259 if fi.IsDir() && !strings.HasSuffix(r.URL.Path, "/") {
260 http.Redirect(w, r, r.URL.Path+"/", http.StatusTemporaryRedirect)
261 return true
262 } else if !fi.IsDir() && strings.HasSuffix(r.URL.Path, "/") {
263 if h.ContinueNotFound {
264 return false
265 }
266 http.NotFound(w, r)
267 return true
268 }
269
270 if fi.IsDir() {
271 index, err := os.Open(filepath.Join(fspath, "index.html"))
272 if err != nil && os.IsPermission(err) {
273 http.Error(w, "403 - permission denied", http.StatusForbidden)
274 return true
275 } else if err != nil && os.IsNotExist(err) && !h.ListFiles {
276 if h.ContinueNotFound {
277 return false
278 }
279 http.Error(w, "403 - permission denied", http.StatusForbidden)
280 return true
281 } else if err == nil {
282 defer func() {
283 if err := index.Close(); err != nil {
284 log().Check(err, "closing index file for serving")
285 }
286 }()
287
288 var ifi os.FileInfo
289 ifi, err = index.Stat()
290 if err == nil {
291 w.Header().Set("Content-Type", "text/html; charset=utf-8")
292 serveFile("index.html", ifi, index)
293 return true
294 }
295 }
296 if !os.IsNotExist(err) {
297 log().Errorx("stat for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
298 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
299 return true
300 }
301
302 type File struct {
303 Name string
304 Size int64
305 SizeReadable string
306 SizePad bool // Whether the size needs padding because it has no decimal point.
307 Modified string
308 }
309 files := []File{}
310 if r.URL.Path != "/" {
311 files = append(files, File{"..", 0, "", false, ""})
312 }
313 for {
314 l, err := f.Readdir(1000)
315 for _, e := range l {
316 mb := float64(e.Size()) / (1024 * 1024)
317 var size string
318 var sizepad bool
319 if !e.IsDir() {
320 if mb >= 10 {
321 size = fmt.Sprintf("%d", int64(mb))
322 sizepad = true
323 } else {
324 size = fmt.Sprintf("%.2f", mb)
325 }
326 }
327 const dateTime = "2006-01-02 15:04:05" // time.DateTime, but only since go1.20.
328 modified := e.ModTime().UTC().Format(dateTime)
329 f := File{e.Name(), e.Size(), size, sizepad, modified}
330 if e.IsDir() {
331 f.Name += "/"
332 }
333 files = append(files, f)
334 }
335 if err == io.EOF {
336 break
337 } else if err != nil {
338 log().Errorx("reading directory for file listing", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
339 http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
340 return true
341 }
342 }
343 sort.Slice(files, func(i, j int) bool {
344 return files[i].Name < files[j].Name
345 })
346 hdr := w.Header()
347 hdr.Set("Content-Type", "text/html; charset=utf-8")
348 for k, v := range h.ResponseHeaders {
349 if !strings.EqualFold(k, "content-type") {
350 hdr.Add(k, v)
351 }
352 }
353 err = lsTemplate.Execute(w, map[string]any{"Files": files})
354 if err != nil {
355 log().Check(err, "executing directory listing template")
356 }
357 return true
358 }
359
360 serveFile(fspath, fi, f)
361 return true
362}
363
364// HandleRedirect writes a response with an HTTP redirect.
365func HandleRedirect(h *config.WebRedirect, w http.ResponseWriter, r *http.Request) (handled bool) {
366 var dstpath string
367 if h.OrigPath == nil {
368 // No path rewrite necessary.
369 dstpath = r.URL.Path
370 } else if !h.OrigPath.MatchString(r.URL.Path) {
371 http.NotFound(w, r)
372 return true
373 } else {
374 dstpath = h.OrigPath.ReplaceAllString(r.URL.Path, h.ReplacePath)
375 }
376
377 u := *r.URL
378 u.Opaque = ""
379 u.RawPath = ""
380 u.OmitHost = false
381 if h.URL != nil {
382 u.Scheme = h.URL.Scheme
383 u.Host = h.URL.Host
384 u.ForceQuery = h.URL.ForceQuery
385 u.RawQuery = h.URL.RawQuery
386 u.Fragment = h.URL.Fragment
387 if r.URL.RawQuery != "" {
388 if u.RawQuery != "" {
389 u.RawQuery += "&"
390 }
391 u.RawQuery += r.URL.RawQuery
392 }
393 }
394 u.Path = dstpath
395 code := http.StatusPermanentRedirect
396 if h.StatusCode != 0 {
397 code = h.StatusCode
398 }
399
400 // If we would be redirecting to the same scheme,host,path, we would get here again
401 // causing a redirect loop. Instead, this causes this redirect to not match,
402 // allowing to try the next WebHandler. This can be used to redirect all plain http
403 // requests to https.
404 reqscheme := "http"
405 if r.TLS != nil {
406 reqscheme = "https"
407 }
408 if reqscheme == u.Scheme && r.Host == u.Host && r.URL.Path == u.Path {
409 return false
410 }
411
412 http.Redirect(w, r, u.String(), code)
413 return true
414}
415
416// HandleInternal passes the request to an internal service.
417func HandleInternal(h *config.WebInternal, w http.ResponseWriter, r *http.Request) (handled bool) {
418 h.Handler.ServeHTTP(w, r)
419 return true
420}
421
422// HandleForward handles a request by forwarding it to another webserver and
423// passing the response on. I.e. a reverse proxy. It handles websocket
424// connections by monitoring the websocket handshake and then just passing along the
425// websocket frames.
426func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
427 log := func() mlog.Log {
428 return pkglog.WithContext(r.Context())
429 }
430
431 xr := *r
432 r = &xr
433 if h.StripPath {
434 u := *r.URL
435 u.Path = r.URL.Path[len(path):]
436 if !strings.HasPrefix(u.Path, "/") && u.Path != "" {
437 u.Path = "/" + u.Path
438 }
439 u.RawPath = u.EscapedPath()
440 r.URL = &u
441 }
442
443 // Remove any forwarded headers passed in by client.
444 hdr := http.Header{}
445 for k, vl := range r.Header {
446 if k == "Forwarded" || k == "X-Forwarded" || strings.HasPrefix(k, "X-Forwarded-") {
447 continue
448 }
449 hdr[k] = vl
450 }
451 r.Header = hdr
452
453 // Add our own X-Forwarded headers. ReverseProxy will add X-Forwarded-For.
454 r.Header["X-Forwarded-Host"] = []string{r.Host}
455 proto := "http"
456 if r.TLS != nil {
457 proto = "https"
458 }
459 r.Header["X-Forwarded-Proto"] = []string{proto}
460 // note: We are not using "ws" or "wss" for websocket. The request we are
461 // forwarding is http(s), and we don't yet know if the backend even supports
462 // websockets.
463
464 // todo: add Forwarded header? is anyone using it?
465
466 // If we see an Upgrade: websocket, we're going to assume the client needs
467 // websocket and only attempt to talk websocket with the backend. If the backend
468 // doesn't do websocket, we'll send back a "bad request" response. For other values
469 // of Upgrade, we don't do anything special.
470 // https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml
471 // Upgrade: ../rfc/9110:2798
472 // Upgrade headers are not for http/1.0, ../rfc/9110:2880
473 // Websocket client "handshake" is described at ../rfc/6455:1134
474 upgrade := r.Header.Get("Upgrade")
475 if upgrade != "" && !(r.ProtoMajor == 1 && r.ProtoMinor == 0) {
476 // Websockets have case-insensitive string "websocket".
477 for s := range strings.SplitSeq(upgrade, ",") {
478 if strings.EqualFold(textproto.TrimString(s), "websocket") {
479 forwardWebsocket(h, w, r, path)
480 return true
481 }
482 }
483 }
484
485 // ReverseProxy will append any remaining path to the configured target URL.
486 proxy := newSingleHostReverseProxy(h.TargetURL)
487 proxy.FlushInterval = time.Duration(-1) // Flush after each write.
488 proxy.ErrorLog = golog.New(mlog.LogWriter(mlog.New("net/http/httputil", nil).WithContext(r.Context()), mlog.LevelDebug, "reverseproxy error"), "", 0)
489 proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
490 if errors.Is(err, context.Canceled) {
491 log().Debugx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
492 return
493 }
494 log().Errorx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
495 if os.IsTimeout(err) {
496 http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
497 } else {
498 http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
499 }
500 }
501 whdr := w.Header()
502 for k, v := range h.ResponseHeaders {
503 whdr.Add(k, v)
504 }
505 proxy.ServeHTTP(w, r)
506 return true
507}
508
var (
	// errResponseNotWebsocket is wrapped by errors about a backend response that is
	// not a valid websocket handshake response.
	errResponseNotWebsocket = errors.New("not a valid websocket response to request")

	// errNotImplemented is wrapped by errors about forwarding that needs
	// functionality that isn't implemented, such as connecting to a websocket
	// backend through a proxy.
	errNotImplemented = errors.New("functionality not yet implemented")
)
511
512// Request has an Upgrade: websocket header. Check more websocketiness about the
513// request. If it looks good, we forward it to the backend. If the backend responds
514// with a valid websocket response, indicating it is indeed a websocket server, we
515// pass the response along and start copying data between the client and the
516// backend. We don't look at the frames and payloads. The backend already needs to
517// know enough websocket to handle the frames. It wouldn't necessarily hurt to
518// monitor the frames too, and check if they are valid, but it's quite a bit of
519// work for little benefit. Besides, the whole point of websockets is to exchange
520// bytes without HTTP being in the way, so let's do that.
521func forwardWebsocket(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
522 log := func() mlog.Log {
523 return pkglog.WithContext(r.Context())
524 }
525
526 lw := w.(*loggingWriter)
527 lw.WebsocketRequest = true // For correct protocol in metrics.
528
529 // We check the requested websocket version first. A future websocket version may
530 // have different request requirements.
531 // ../rfc/6455:1160
532 wsversion := r.Header.Get("Sec-WebSocket-Version")
533 if wsversion != "13" {
534 // Indicate we only support version 13. Should get a client from the future to fall back to version 13.
535 // ../rfc/6455:1435
536 w.Header().Set("Sec-WebSocket-Version", "13")
537 http.Error(w, "400 - bad request - websockets only supported with version 13"+recvid(r), http.StatusBadRequest)
538 lw.error(fmt.Errorf("Sec-WebSocket-Version %q not supported", wsversion))
539 return true
540 }
541
542 // ../rfc/6455:1143
543 if r.Method != "GET" {
544 http.Error(w, "400 - bad request - websockets only allowed with method GET"+recvid(r), http.StatusBadRequest)
545 lw.error(fmt.Errorf("websocket request only allowed with method GET"))
546 return true
547 }
548
549 // ../rfc/6455:1153
550 var connectionUpgrade bool
551 for s := range strings.SplitSeq(r.Header.Get("Connection"), ",") {
552 if strings.EqualFold(textproto.TrimString(s), "upgrade") {
553 connectionUpgrade = true
554 break
555 }
556 }
557 if !connectionUpgrade {
558 http.Error(w, "400 - bad request - connection header must be \"upgrade\""+recvid(r), http.StatusBadRequest)
559 lw.error(fmt.Errorf(`connection header is %q, must be "upgrade"`, r.Header.Get("Connection")))
560 return true
561 }
562
563 // ../rfc/6455:1156
564 wskey := r.Header.Get("Sec-WebSocket-Key")
565 key, err := base64.StdEncoding.DecodeString(wskey)
566 if err != nil || len(key) != 16 {
567 http.Error(w, "400 - bad request - websockets requires Sec-WebSocket-Key with 16 bytes base64-encoded value"+recvid(r), http.StatusBadRequest)
568 lw.error(fmt.Errorf("bad Sec-WebSocket-Key %q, must be 16 byte base64-encoded value", wskey))
569 return true
570 }
571
572 // ../rfc/6455:1162
573 // We don't look at the origin header. The backend needs to handle it, if it thinks
574 // that helps...
575 // We also don't look at Sec-WebSocket-Protocol and Sec-WebSocket-Extensions. The
576 // backend can set them, but it doesn't influence our forwarding of the data.
577
578 // If this is not a hijacker, there is not point in connecting to the backend.
579 hj, ok := lw.W.(http.Hijacker)
580 var cbr *bufio.ReadWriter
581 if !ok {
582 log().Info("cannot turn http connection into tcp connection (http.Hijacker)")
583 http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
584 lw.error(fmt.Errorf("connection not a http.Hijacker (%T)", lw.W))
585 return
586 }
587
588 freq := *r
589 freq.Proto = "HTTP/1.1"
590 freq.ProtoMajor = 1
591 freq.ProtoMinor = 1
592 fresp, beconn, err := websocketTransact(r.Context(), h.TargetURL, &freq)
593 if err != nil {
594 if errors.Is(err, errResponseNotWebsocket) {
595 http.Error(w, "400 - bad request - websocket not supported"+recvid(r), http.StatusBadRequest)
596 } else if errors.Is(err, errNotImplemented) {
597 http.Error(w, "501 - not implemented - "+err.Error()+recvid(r), http.StatusNotImplemented)
598 } else if os.IsTimeout(err) {
599 http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
600 } else {
601 http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
602 }
603 lw.error(err)
604 return
605 }
606 defer func() {
607 if beconn != nil {
608 if err := beconn.Close(); err != nil {
609 log().Check(err, "closing backend websocket connection")
610 }
611 }
612 }()
613
614 // Hijack the client connection so we can write the response ourselves, and start
615 // copying the websocket frames.
616 var cconn net.Conn
617 cconn, cbr, err = hj.Hijack()
618 if err != nil {
619 log().Debugx("cannot turn http transaction into websocket connection", err)
620 http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
621 lw.error(err)
622 return
623 }
624 defer func() {
625 if cconn != nil {
626 if err := cconn.Close(); err != nil {
627 log().Check(err, "closing client websocket connection")
628 }
629 }
630 }()
631
632 // Below this point, we can no longer write to the ResponseWriter.
633
634 // Mark as websocket response, for logging.
635 lw.WebsocketResponse = true
636 lw.setStatusCode(fresp.StatusCode)
637
638 for k, v := range h.ResponseHeaders {
639 fresp.Header.Add(k, v)
640 }
641
642 // Write the response to the client, completing its websocket handshake.
643 if err := fresp.Write(cconn); err != nil {
644 lw.error(fmt.Errorf("writing websocket response to client: %w", err))
645 return
646 }
647
648 errc := make(chan error, 1)
649
650 // Copy from client to backend.
651 go func() {
652 buf, err := cbr.Peek(cbr.Reader.Buffered())
653 if err != nil {
654 errc <- err
655 return
656 }
657 if len(buf) > 0 {
658 n, err := beconn.Write(buf)
659 if err != nil {
660 errc <- err
661 return
662 }
663 lw.SizeFromClient += int64(n)
664 }
665 n, err := io.Copy(beconn, cconn)
666 lw.SizeFromClient += n
667 errc <- err
668 }()
669
670 // Copy from backend to client.
671 go func() {
672 n, err := io.Copy(cconn, beconn)
673 lw.SizeToClient = n
674 errc <- err
675 }()
676
677 // Stop and close connection on first error from either size, typically a closed
678 // connection whose closing was already announced with a websocket frame.
679 lw.error(<-errc)
680 // Close connections so other goroutine stops as well.
681 if err := cconn.Close(); err != nil {
682 log().Check(err, "closing client websocket connection")
683 }
684 if err := beconn.Close(); err != nil {
685 log().Check(err, "closing backend websocket connection")
686 }
687 // Wait for goroutine so it has updated the logWriter.Size*Client fields before we
688 // continue with logging.
689 <-errc
690 cconn = nil
691 return true
692}
693
694func websocketTransact(ctx context.Context, targetURL *url.URL, r *http.Request) (rresp *http.Response, rconn net.Conn, rerr error) {
695 log := func() mlog.Log {
696 return pkglog.WithContext(r.Context())
697 }
698
699 // Dial the backend, possibly doing TLS. We assume the net/http DefaultTransport is
700 // unmodified.
701 transport := http.DefaultTransport.(*http.Transport)
702
703 // We haven't implemented using a proxy for websocket requests yet. If we need one,
704 // return an error instead of trying to connect directly, which would be a
705 // potential security issue.
706 treq := *r
707 treq.URL = targetURL
708 if purl, err := transport.Proxy(&treq); err != nil {
709 return nil, nil, fmt.Errorf("determining proxy for websocket backend connection: %w", err)
710 } else if purl != nil {
711 return nil, nil, fmt.Errorf("%w: proxy required for websocket connection to backend", errNotImplemented) // todo: implement?
712 }
713
714 host, port, err := net.SplitHostPort(targetURL.Host)
715 if err != nil {
716 host = targetURL.Host
717 if targetURL.Scheme == "https" {
718 port = "443"
719 } else {
720 port = "80"
721 }
722 }
723 addr := net.JoinHostPort(host, port)
724 conn, err := transport.DialContext(r.Context(), "tcp", addr)
725 if err != nil {
726 return nil, nil, fmt.Errorf("dial: %w", err)
727 }
728 if targetURL.Scheme == "https" {
729 tlsconn := tls.Client(conn, transport.TLSClientConfig)
730 ctx, cancel := context.WithTimeout(r.Context(), transport.TLSHandshakeTimeout)
731 defer cancel()
732 if err := tlsconn.HandshakeContext(ctx); err != nil {
733 return nil, nil, fmt.Errorf("tls handshake: %w", err)
734 }
735 conn = tlsconn
736 }
737 defer func() {
738 if rerr != nil {
739 if xerr := conn.Close(); xerr != nil {
740 log().Check(xerr, "cleaning up websocket connection")
741 }
742 }
743 }()
744
745 // todo: make timeout configurable?
746 if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
747 log().Check(err, "set deadline for websocket request to backend")
748 }
749
750 // Set clean connection headers.
751 removeHopByHopHeaders(r.Header)
752 r.Header.Set("Connection", "Upgrade")
753 r.Header.Set("Upgrade", "websocket")
754
755 // Write the websocket request to the backend.
756 if err := r.Write(conn); err != nil {
757 return nil, nil, fmt.Errorf("writing request to backend: %w", err)
758 }
759
760 // Read response from backend.
761 br := bufio.NewReader(conn)
762 resp, err := http.ReadResponse(br, r)
763 if err != nil {
764 return nil, nil, fmt.Errorf("reading response from backend: %w", err)
765 }
766 defer func() {
767 if rerr != nil {
768 if xerr := resp.Body.Close(); xerr != nil {
769 log().Check(xerr, "closing response body after error")
770 }
771 }
772 }()
773 if err := conn.SetDeadline(time.Time{}); err != nil {
774 log().Check(err, "clearing deadline on websocket connection to backend")
775 }
776
777 // Check that the response from the backend server indicates it is websocket. If
778 // not, don't pass the backend response, but an error that websocket is not
779 // appropriate.
780 if err := checkWebsocketResponse(resp, r); err != nil {
781 return resp, nil, err
782 }
783
784 // note: net/http.Response.Body documents that it implements io.Writer for a
785 // status: 101 response. But that's not the case when the response has been read
786 // with http.ReadResponse. We'll write to the connection directly.
787
788 buf, err := br.Peek(br.Buffered())
789 if err != nil {
790 return resp, nil, fmt.Errorf("peek at buffered data written by backend: %w", err)
791 }
792 return resp, websocketConn{io.MultiReader(bytes.NewReader(buf), conn), conn}, nil
793}
794
// websocketConn is a net.Conn whose reads come from r instead of the embedded
// connection. websocketTransact sets r to an io.MultiReader that first returns the
// data already buffered by the bufio.Reader used for http.ReadResponse, followed
// by the connection itself. All other methods, such as Write and Close, go to the
// embedded net.Conn.
type websocketConn struct {
	r io.Reader
	net.Conn
}

// Read reads from r instead of from the embedded connection.
func (wc websocketConn) Read(p []byte) (int, error) {
	n, err := wc.r.Read(p)
	return n, err
}
805
806// Check that an HTTP response (from a backend) is a valid websocket response, i.e.
807// that it accepts the WebSocket "upgrade".
808// ../rfc/6455:1299
809func checkWebsocketResponse(resp *http.Response, req *http.Request) error {
810 if resp.StatusCode != 101 {
811 return fmt.Errorf("%w: response http status not 101 but %s", errResponseNotWebsocket, resp.Status)
812 }
813 if upgrade := resp.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") {
814 return fmt.Errorf(`%w: response http status is 101, but Upgrade header is %q, should be "websocket"`, errResponseNotWebsocket, upgrade)
815 }
816 if connection := resp.Header.Get("Connection"); !strings.EqualFold(connection, "upgrade") {
817 return fmt.Errorf(`%w: response http status is 101, Upgrade is websocket, but Connection header is %q, should be "Upgrade"`, errResponseNotWebsocket, connection)
818 }
819 accept, err := base64.StdEncoding.DecodeString(resp.Header.Get("Sec-WebSocket-Accept"))
820 if err != nil {
821 return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but Sec-WebSocket-Accept header is not valid base64: %v`, errResponseNotWebsocket, err)
822 }
823 exp := sha1.Sum([]byte(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
824 if !bytes.Equal(accept, exp[:]) {
825 return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but backend Sec-WebSocket-Accept value does not match`, errResponseNotWebsocket)
826 }
827 // We don't have requirements for the other Sec-WebSocket headers. ../rfc/6455:1340
828 return nil
829}
830
// hopHeaders lists the hop-by-hop headers that are removed before a request is
// sent to the backend. Copied from Go 1.20.4 src/net/http/httputil/reverseproxy.go.
//
// As of RFC 7230, hop-by-hop headers are required to appear in the Connection
// header field. The headers below are those defined by the obsoleted RFC 2616
// (section 13.5.1), kept for backward compatibility.
// ../rfc/2616:5128
var hopHeaders = []string{
	"Connection",
	// Non-standard, but still sent by libcurl and rejected by e.g. google.
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	// Canonical form of "TE".
	"Te",
	// "Trailer", not "Trailers" as in the RFC text; see
	// https://www.rfc-editor.org/errata_search.php?eid=4522
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
849
850// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
851// removeHopByHopHeaders removes hop-by-hop headers.
852func removeHopByHopHeaders(h http.Header) {
853 // RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
854 // ../rfc/7230:2817
855 for _, f := range h["Connection"] {
856 for sf := range strings.SplitSeq(f, ",") {
857 if sf = textproto.TrimString(sf); sf != "" {
858 h.Del(sf)
859 }
860 }
861 }
862 // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
863 // This behavior is superseded by the RFC 7230 Connection header, but
864 // preserve it for backwards compatibility.
865 // ../rfc/2616:5128
866 for _, f := range hopHeaders {
867 h.Del(f)
868 }
869}
870