// Package mtasts implements MTA-STS (SMTP MTA Strict Transport Security, RFC 8461),
// which allows a domain to specify SMTP TLS requirements.
//
// SMTP for message delivery to a remote mail server always starts out unencrypted,
// in plain text. STARTTLS allows upgrading the connection to TLS, but is optional
// and by default mail servers will fall back to plain text communication if
// STARTTLS does not work (which can be sabotaged by DNS manipulation or SMTP
// connection manipulation). MTA-STS can specify a policy for requiring STARTTLS to
// be used for message delivery. A TXT DNS record at "_mta-sts.<domain>" specifies
// the version of the policy, and
// "https://mta-sts.<domain>/.well-known/mta-sts.txt" serves the policy.
 
23	"github.com/prometheus/client_golang/prometheus"
 
24	"github.com/prometheus/client_golang/prometheus/promauto"
 
26	"github.com/mjl-/adns"
 
28	"github.com/mjl-/mox/dns"
 
29	"github.com/mjl-/mox/metrics"
 
30	"github.com/mjl-/mox/mlog"
 
31	"github.com/mjl-/mox/moxio"
 
34var xlog = mlog.New("mtasts")
 
37	metricGet = promauto.NewHistogramVec(
 
38		prometheus.HistogramOpts{
 
39			Name:    "mox_mtasts_get_duration_seconds",
 
40			Help:    "MTA-STS get of policy, including lookup, duration and result.",
 
41			Buckets: []float64{0.01, 0.05, 0.100, 0.5, 1, 5, 10, 20},
 
44			"result", // ok, lookuperror, fetcherror
 
49// Pair is an extension key/value pair in a MTA-STS DNS record or policy.
 
55// Record is an MTA-STS DNS record, served under "_mta-sts.<domain>" as a TXT
 
60//	v=STSv1; id=20160831085700Z
 
62	Version    string // "STSv1", for "v=". Required.
 
63	ID         string // Record version, for "id=". Required.
 
64	Extensions []Pair // Optional extensions.
 
67// String returns a textual version of the MTA-STS record for use as DNS TXT
 
69func (r Record) String() string {
 
70	b := &strings.Builder{}
 
71	fmt.Fprint(b, "v="+r.Version)
 
72	fmt.Fprint(b, "; id="+r.ID)
 
73	for _, p := range r.Extensions {
 
74		fmt.Fprint(b, "; "+p.Key+"="+p.Value)
 
79// Mode indicates how the policy should be interpreted.
 
85	ModeEnforce Mode = "enforce" // Policy must be followed, i.e. deliveries must fail if a TLS connection cannot be made.
 
86	ModeTesting Mode = "testing" // In case TLS cannot be negotiated, plain SMTP can be used, but failures must be reported, e.g. with TLS-RPT.
 
87	ModeNone    Mode = "none"    // In case MTA-STS is not or no longer implemented.
 
90// STSMX is an allowlisted MX host name/pattern.
 
91// todo: find a way to name this just STSMX without getting duplicate names for "MX" in the sherpa api.
 
93	// "*." wildcard, e.g. if a subdomain matches. A wildcard must match exactly one
 
94	// label. *.example.com matches mail.example.com, but not example.com, and not
 
95	// foor.bar.example.com.
 
101// LogString returns a loggable string representing the host, with both unicode
 
102// and ascii version for IDNA domains.
 
103func (s STSMX) LogString() string {
 
108	if s.Domain.Unicode == "" {
 
109		return pre + s.Domain.ASCII
 
111	return pre + s.Domain.Unicode + "/" + pre + s.Domain.ASCII
 
114// Policy is an MTA-STS policy as served at "https://mta-sts.<domain>/.well-known/mta-sts.txt".
 
116	Version       string // "STSv1"
 
119	MaxAgeSeconds int // How long this policy can be cached. Suggested values are in weeks or more.
 
123// String returns a textual representation for serving at the well-known URL.
 
124func (p Policy) String() string {
 
125	b := &strings.Builder{}
 
126	line := func(k, v string) {
 
127		fmt.Fprint(b, k+": "+v+"\n")
 
129	line("version", p.Version)
 
130	line("mode", string(p.Mode))
 
131	line("max_age", fmt.Sprintf("%d", p.MaxAgeSeconds))
 
132	for _, mx := range p.MX {
 
133		s := mx.Domain.Name()
 
142// Matches returns whether the hostname matches the mx list in the policy.
 
143func (p *Policy) Matches(host dns.Domain) bool {
 
145	for _, mx := range p.MX {
 
147			v := strings.SplitN(host.ASCII, ".", 2)
 
148			if len(v) == 2 && v[1] == mx.Domain.ASCII {
 
151		} else if host == mx.Domain {
 
158// TLSReportFailureReason returns a concise error for known error types, or an
 
159// empty string. For use in TLSRPT.
 
160func TLSReportFailureReason(err error) string {
 
161	// If this is a DNSSEC authentication error, we'll collect it for TLS reporting.
 
163	var errCode adns.ErrorCode
 
164	if errors.As(err, &errCode) && errCode.IsAuthentication() {
 
165		return fmt.Sprintf("dns-extended-error-%d-%s", errCode, strings.ReplaceAll(errCode.String(), " ", "-"))
 
168	for _, e := range mtastsErrors {
 
169		if errors.Is(err, e) {
 
170			s := strings.TrimPrefix(e.Error(), "mtasts: ")
 
171			return strings.ReplaceAll(s, " ", "-")
 
177var mtastsErrors = []error{
 
178	ErrNoRecord, ErrMultipleRecords, ErrDNS, ErrRecordSyntax, // Lookup
 
179	ErrNoPolicy, ErrPolicyFetch, ErrPolicySyntax, // Fetch
 
184	ErrNoRecord        = errors.New("mtasts: no mta-sts dns txt record") // Domain does not implement MTA-STS. If a cached non-expired policy is available, it should still be used.
 
185	ErrMultipleRecords = errors.New("mtasts: multiple mta-sts records")  // Should be treated as if domain does not implement MTA-STS, unless a cached non-expired policy is available.
 
186	ErrDNS             = errors.New("mtasts: dns lookup")                // For temporary DNS errors.
 
187	ErrRecordSyntax    = errors.New("mtasts: record syntax error")
 
190// LookupRecord looks up the MTA-STS TXT DNS record at "_mta-sts.<domain>",
 
191// following CNAME records, and returns the parsed MTA-STS record and the DNS TXT
 
193func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rerr error) {
 
194	log := xlog.WithContext(ctx)
 
197		log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("duration", time.Since(start)))
 
202	// We lookup the txt record, but must follow CNAME records when the TXT does not
 
203	// exist. LookupTXT follows CNAMEs.
 
204	name := "_mta-sts." + domain.ASCII + "."
 
206	txts, _, err := dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name)
 
207	if dns.IsNotFound(err) {
 
208		return nil, "", ErrNoRecord
 
209	} else if err != nil {
 
210		return nil, "", fmt.Errorf("%w: %s", ErrDNS, err)
 
215	for _, txt := range txts {
 
216		r, ismtasts, err := ParseRecord(txt)
 
219			// "v=STSv1 ;" (note the space) as a non-STS record too in case of multiple TXT
 
220			// records. We treat it as an STS record that is invalid, which is possibly more
 
228			return nil, "", ErrMultipleRecords
 
234		return nil, "", ErrNoRecord
 
236	return record, text, nil
 
239// Policy fetch errors.
 
241	ErrNoPolicy     = errors.New("mtasts: no policy served")    // If the name "mta-sts.<domain>" does not exist in DNS or if webserver returns HTTP status 404 "File not found".
 
242	ErrPolicyFetch  = errors.New("mtasts: cannot fetch policy") // E.g. for HTTP request errors.
 
243	ErrPolicySyntax = errors.New("mtasts: policy syntax error")
 
246// HTTPClient is used by FetchPolicy for HTTP requests.
 
247var HTTPClient = &http.Client{
 
248	CheckRedirect: func(req *http.Request, via []*http.Request) error {
 
253// FetchPolicy fetches a new policy for the domain, at
 
254// https://mta-sts.<domain>/.well-known/mta-sts.txt.
 
256// FetchPolicy returns the parsed policy and the literal policy text as fetched
 
257// from the server. If a policy was fetched but could not be parsed, the policyText
 
258// return value will be set.
 
260// Policies longer than 64KB result in a syntax error.
 
262// If an error is returned, callers should back off for 5 minutes until the next
 
264func FetchPolicy(ctx context.Context, domain dns.Domain) (policy *Policy, policyText string, rerr error) {
 
265	log := xlog.WithContext(ctx)
 
268		log.Debugx("mtasts fetch policy result", rerr, mlog.Field("domain", domain), mlog.Field("policy", policy), mlog.Field("policytext", policyText), mlog.Field("duration", time.Since(start)))
 
272	ctx, cancel := context.WithTimeout(ctx, time.Minute)
 
275	// TLS requirements are what the Go standard library checks: trusted, non-expired,
 
277	url := "https://mta-sts." + domain.Name() + "/.well-known/mta-sts.txt"
 
278	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
 
280		return nil, "", fmt.Errorf("%w: http request: %s", ErrPolicyFetch, err)
 
282	// We are not likely to reuse a connection: we cache policies and negative DNS
 
283	// responses. So don't keep connections open unnecessarily.
 
286	resp, err := HTTPClient.Do(req)
 
287	if dns.IsNotFound(err) {
 
288		return nil, "", ErrNoPolicy
 
291		// We pass along underlying TLS certificate errors.
 
292		return nil, "", fmt.Errorf("%w: http get: %w", ErrPolicyFetch, err)
 
294	metrics.HTTPClientObserve(ctx, "mtasts", req.Method, resp.StatusCode, err, start)
 
295	defer resp.Body.Close()
 
296	if resp.StatusCode == http.StatusNotFound {
 
297		return nil, "", ErrNoPolicy
 
299	if resp.StatusCode != http.StatusOK {
 
301		return nil, "", fmt.Errorf("%w: http status %s while status 200 is required", ErrPolicyFetch, resp.Status)
 
304	// We don't look at Content-Type and charset. It should be ASCII or UTF-8, we'll
 
308	buf, err := io.ReadAll(&moxio.LimitReader{R: resp.Body, Limit: 64 * 1024})
 
310		return nil, "", fmt.Errorf("%w: reading policy: %s", ErrPolicySyntax, err)
 
312	policyText = string(buf)
 
313	policy, err = ParsePolicy(policyText)
 
315		return nil, policyText, fmt.Errorf("parsing policy: %w", err)
 
317	return policy, policyText, nil
 
320// Get looks up the MTA-STS DNS record and fetches the policy.
 
322// Errors can be those returned by LookupRecord and FetchPolicy.
 
324// If a valid policy cannot be retrieved, a sender must treat the domain as not
 
325// implementing MTA-STS. If a sender has a non-expired cached policy, that policy
 
328// If a record was retrieved, but a policy could not be retrieved/parsed, the
 
329// record is still returned.
 
331// Also see Get in package mtastsdb.
 
332func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (record *Record, policy *Policy, policyText string, err error) {
 
333	log := xlog.WithContext(ctx)
 
335	result := "lookuperror"
 
337		metricGet.WithLabelValues(result).Observe(float64(time.Since(start)) / float64(time.Second))
 
338		log.Debugx("mtasts get result", err, mlog.Field("domain", domain), mlog.Field("record", record), mlog.Field("policy", policy), mlog.Field("duration", time.Since(start)))
 
341	record, _, err = LookupRecord(ctx, resolver, domain)
 
343		return nil, nil, "", err
 
346	result = "fetcherror"
 
347	policy, policyText, err = FetchPolicy(ctx, domain)
 
349		return record, nil, "", err
 
353	return record, policy, policyText, nil