1package main
2
3/*
4note: these testdata paths are not in the repo, you should gather some of your
5own ham/spam emails.
6
7./mox junk train testdata/train/ham testdata/train/spam
8./mox junk train -sent-dir testdata/sent testdata/train/ham testdata/train/spam
9./mox junk check 'testdata/check/ham/mail1'
10./mox junk test testdata/check/ham testdata/check/spam
11./mox junk analyze testdata/train/ham testdata/train/spam
12./mox junk analyze -top-words 10 -train-ratio 0.5 -spam-threshold 0.85 -max-power 0.01 -sent-dir testdata/sent testdata/train/ham testdata/train/spam
13./mox junk play -top-words 10 -train-ratio 0.5 -spam-threshold 0.85 -max-power 0.01 -sent-dir testdata/sent testdata/train/ham testdata/train/spam
14*/
15
16import (
17 "context"
18 "flag"
19 "fmt"
20 "log"
21 mathrand "math/rand"
22 "os"
23 "sort"
24 "time"
25
26 "github.com/mjl-/mox/junk"
27 "github.com/mjl-/mox/message"
28 "github.com/mjl-/mox/mlog"
29 "github.com/mjl-/mox/mox-"
30)
31
32type junkArgs struct {
33 params junk.Params
34 spamThreshold float64
35 trainRatio float64
36 seed bool
37 sentDir string
38 databasePath, bloomfilterPath string
39 debug bool
40}
41
42func (a junkArgs) SetLogLevel() {
43 mox.Conf.Log[""] = mlog.LevelInfo
44 if a.debug {
45 mox.Conf.Log[""] = mlog.LevelDebug
46 }
47 mlog.SetConfig(mox.Conf.Log)
48}
49
50func junkFlags(fs *flag.FlagSet) (a junkArgs) {
51 fs.BoolVar(&a.params.Onegrams, "one-grams", false, "use 1-grams, i.e. single words, for scoring")
52 fs.BoolVar(&a.params.Twograms, "two-grams", true, "use 2-grams, i.e. word pairs, for scoring")
53 fs.BoolVar(&a.params.Threegrams, "three-grams", false, "use 3-grams, i.e. word triplets, for scoring")
54 fs.Float64Var(&a.params.MaxPower, "max-power", 0.05, "maximum word power, e.g. min 0.05/max 0.95")
55 fs.Float64Var(&a.params.IgnoreWords, "ignore-words", 0.1, "ignore words with ham/spaminess within this distance from 0.5")
56 fs.IntVar(&a.params.TopWords, "top-words", 10, "number of top spam and number of top ham words from email to use")
57 fs.IntVar(&a.params.RareWords, "rare-words", 1, "words are rare if encountered this number during training, and skipped for scoring")
58 fs.BoolVar(&a.debug, "debug", false, "print debug logging when calculating spam probability")
59
60 fs.Float64Var(&a.spamThreshold, "spam-threshold", 0.95, "probability where message is seen as spam")
61 fs.Float64Var(&a.trainRatio, "train-ratio", 0.5, "part of data to use for training versus analyzing (for analyze only)")
62 fs.StringVar(&a.sentDir, "sent-dir", "", "directory with sent mails, for training")
63 fs.BoolVar(&a.seed, "seed", false, "seed prng before analysis")
64 fs.StringVar(&a.databasePath, "dbpath", "filter.db", "database file for ham/spam words")
65 fs.StringVar(&a.bloomfilterPath, "bloompath", "filter.bloom", "bloom filter for ignoring unique strings")
66
67 return
68}
69
70func listDir(dir string) (l []string) {
71 files, err := os.ReadDir(dir)
72 xcheckf(err, "listing directory %q", dir)
73 for _, f := range files {
74 l = append(l, f.Name())
75 }
76 return l
77}
78
79func must(f *junk.Filter, err error) *junk.Filter {
80 xcheckf(err, "filter")
81 return f
82}
83
84func cmdJunkTrain(c *cmd) {
85 c.unlisted = true
86 c.params = "hamdir spamdir"
87 c.help = "Train a junk filter with messages from hamdir and spamdir."
88 a := junkFlags(c.flag)
89 args := c.Parse()
90 if len(args) != 2 {
91 c.Usage()
92 }
93 a.SetLogLevel()
94
95 f := must(junk.NewFilter(context.Background(), mlog.New("junktrain"), a.params, a.databasePath, a.bloomfilterPath))
96 defer func() {
97 if err := f.Close(); err != nil {
98 log.Printf("closing junk filter: %v", err)
99 }
100 }()
101
102 hamFiles := listDir(args[0])
103 spamFiles := listDir(args[1])
104 var sentFiles []string
105 if a.sentDir != "" {
106 sentFiles = listDir(a.sentDir)
107 }
108
109 err := f.TrainDirs(args[0], a.sentDir, args[1], hamFiles, sentFiles, spamFiles)
110 xcheckf(err, "train")
111}
112
113func cmdJunkCheck(c *cmd) {
114 c.unlisted = true
115 c.params = "mailfile"
116 c.help = "Check an email message against a junk filter, printing the probability of spam on a scale from 0 to 1."
117 a := junkFlags(c.flag)
118 args := c.Parse()
119 if len(args) != 1 {
120 c.Usage()
121 }
122 a.SetLogLevel()
123
124 f := must(junk.OpenFilter(context.Background(), mlog.New("junkcheck"), a.params, a.databasePath, a.bloomfilterPath, false))
125 defer func() {
126 if err := f.Close(); err != nil {
127 log.Printf("closing junk filter: %v", err)
128 }
129 }()
130
131 prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), args[0])
132 xcheckf(err, "testing mail")
133
134 fmt.Printf("%.6f\n", prob)
135}
136
137func cmdJunkTest(c *cmd) {
138 c.unlisted = true
139 c.params = "hamdir spamdir"
140 c.help = "Check a directory with hams and one with spams against the junk filter, and report the success ratio."
141 a := junkFlags(c.flag)
142 args := c.Parse()
143 if len(args) != 2 {
144 c.Usage()
145 }
146 a.SetLogLevel()
147
148 f := must(junk.OpenFilter(context.Background(), mlog.New("junktest"), a.params, a.databasePath, a.bloomfilterPath, false))
149 defer func() {
150 if err := f.Close(); err != nil {
151 log.Printf("closing junk filter: %v", err)
152 }
153 }()
154
155 testDir := func(dir string, ham bool) (int, int) {
156 ok, bad := 0, 0
157 files, err := os.ReadDir(dir)
158 xcheckf(err, "readdir %q", dir)
159 for _, fi := range files {
160 path := dir + "/" + fi.Name()
161 prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path)
162 if err != nil {
163 log.Printf("classify message %q: %s", path, err)
164 continue
165 }
166 if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold {
167 ok++
168 } else {
169 bad++
170 }
171 if ham && prob > a.spamThreshold {
172 fmt.Printf("ham %q: %.4f\n", path, prob)
173 }
174 if !ham && prob < a.spamThreshold {
175 fmt.Printf("spam %q: %.4f\n", path, prob)
176 }
177 }
178 return ok, bad
179 }
180
181 nhamok, nhambad := testDir(args[0], true)
182 nspamok, nspambad := testDir(args[1], false)
183 fmt.Printf("total ham, ok %d, bad %d\n", nhamok, nhambad)
184 fmt.Printf("total spam, ok %d, bad %d\n", nspamok, nspambad)
185 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
186 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
187 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
188}
189
190func cmdJunkAnalyze(c *cmd) {
191 c.unlisted = true
192 c.params = "hamdir spamdir"
193 c.help = `Analyze a directory with ham messages and one with spam messages.
194
195A part of the messages is used for training, and remaining for testing. The
196messages are shuffled, with optional random seed.`
197 a := junkFlags(c.flag)
198 args := c.Parse()
199 if len(args) != 2 {
200 c.Usage()
201 }
202 a.SetLogLevel()
203
204 f := must(junk.NewFilter(context.Background(), mlog.New("junkanalyze"), a.params, a.databasePath, a.bloomfilterPath))
205 defer func() {
206 if err := f.Close(); err != nil {
207 log.Printf("closing junk filter: %v", err)
208 }
209 }()
210
211 hamDir := args[0]
212 spamDir := args[1]
213 hamFiles := listDir(hamDir)
214 spamFiles := listDir(spamDir)
215
216 var rand *mathrand.Rand
217 if a.seed {
218 rand = mathrand.New(mathrand.NewSource(time.Now().UnixMilli()))
219 } else {
220 rand = mathrand.New(mathrand.NewSource(0))
221 }
222
223 shuffle := func(l []string) {
224 count := len(l)
225 for i := range l {
226 n := rand.Intn(count)
227 l[i], l[n] = l[n], l[i]
228 }
229 }
230
231 shuffle(hamFiles)
232 shuffle(spamFiles)
233
234 ntrainham := int(a.trainRatio * float64(len(hamFiles)))
235 ntrainspam := int(a.trainRatio * float64(len(spamFiles)))
236
237 trainHam := hamFiles[:ntrainham]
238 trainSpam := spamFiles[:ntrainspam]
239 testHam := hamFiles[ntrainham:]
240 testSpam := spamFiles[ntrainspam:]
241
242 var trainSent []string
243 if a.sentDir != "" {
244 trainSent = listDir(a.sentDir)
245 }
246
247 err := f.TrainDirs(hamDir, a.sentDir, spamDir, trainHam, trainSent, trainSpam)
248 xcheckf(err, "train")
249
250 testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) {
251 for _, name := range files {
252 path := dir + "/" + name
253 prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path)
254 if err != nil {
255 // log.Infof("%s: %s", path, err)
256 malformed++
257 continue
258 }
259 if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold {
260 ok++
261 } else {
262 bad++
263 }
264 if ham && prob > a.spamThreshold {
265 fmt.Printf("ham %q: %.4f\n", path, prob)
266 }
267 if !ham && prob < a.spamThreshold {
268 fmt.Printf("spam %q: %.4f\n", path, prob)
269 }
270 }
271 return
272 }
273
274 nhamok, nhambad, nmalformedham := testDir(args[0], testHam, true)
275 nspamok, nspambad, nmalformedspam := testDir(args[1], testSpam, false)
276 fmt.Printf("training done, nham %d, nsent %d, nspam %d\n", ntrainham, len(trainSent), ntrainspam)
277 fmt.Printf("total ham, ok %d, bad %d, malformed %d\n", nhamok, nhambad, nmalformedham)
278 fmt.Printf("total spam, ok %d, bad %d, malformed %d\n", nspamok, nspambad, nmalformedspam)
279 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
280 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
281 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
282}
283
284func cmdJunkPlay(c *cmd) {
285 c.unlisted = true
286 c.params = "hamdir spamdir"
287 c.help = "Play messages from ham and spam directory according to their time of arrival and report on junk filter performance."
288 a := junkFlags(c.flag)
289 args := c.Parse()
290 if len(args) != 2 {
291 c.Usage()
292 }
293 a.SetLogLevel()
294
295 f := must(junk.NewFilter(context.Background(), mlog.New("junkplay"), a.params, a.databasePath, a.bloomfilterPath))
296 defer func() {
297 if err := f.Close(); err != nil {
298 log.Printf("closing junk filter: %v", err)
299 }
300 }()
301
302 // We'll go through all emails to find their dates.
303 type msg struct {
304 dir, filename string
305 ham, sent bool
306 t time.Time
307 }
308 var msgs []msg
309
310 var nbad, nnodate, nham, nspam, nsent int
311
312 jlog := mlog.New("junkplay")
313
314 scanDir := func(dir string, ham, sent bool) {
315 for _, name := range listDir(dir) {
316 path := dir + "/" + name
317 mf, err := os.Open(path)
318 xcheckf(err, "open %q", path)
319 fi, err := mf.Stat()
320 xcheckf(err, "stat %q", path)
321 p, err := message.EnsurePart(jlog, false, mf, fi.Size())
322 if err != nil {
323 nbad++
324 if err := mf.Close(); err != nil {
325 log.Printf("closing message file: %v", err)
326 }
327 continue
328 }
329 if p.Envelope.Date.IsZero() {
330 nnodate++
331 if err := mf.Close(); err != nil {
332 log.Printf("closing message file: %v", err)
333 }
334 continue
335 }
336 if err := mf.Close(); err != nil {
337 log.Printf("closing message file: %v", err)
338 }
339 msgs = append(msgs, msg{dir, name, ham, sent, p.Envelope.Date})
340 if sent {
341 nsent++
342 } else if ham {
343 nham++
344 } else {
345 nspam++
346 }
347 }
348 }
349
350 hamDir := args[0]
351 spamDir := args[1]
352 scanDir(hamDir, true, false)
353 scanDir(spamDir, false, false)
354 if a.sentDir != "" {
355 scanDir(a.sentDir, true, true)
356 }
357
358 // Sort the messages, earliest first.
359 sort.Slice(msgs, func(i, j int) bool {
360 return msgs[i].t.Before(msgs[j].t)
361 })
362
363 // Play all messages as if they are coming in. We predict their spaminess, check if
364 // we are right. And we train the system with the result.
365 var nhamok, nhambad, nspamok, nspambad int
366
367 play := func(msg msg) {
368 var words map[string]struct{}
369 path := msg.dir + "/" + msg.filename
370 if !msg.sent {
371 var prob float64
372 var err error
373 prob, words, _, _, err = f.ClassifyMessagePath(context.Background(), path)
374 if err != nil {
375 nbad++
376 return
377 }
378 if msg.ham {
379 if prob < a.spamThreshold {
380 nhamok++
381 } else {
382 nhambad++
383 }
384 } else {
385 if prob > a.spamThreshold {
386 nspamok++
387 } else {
388 nspambad++
389 }
390 }
391 } else {
392 mf, err := os.Open(path)
393 xcheckf(err, "open %q", path)
394 defer func() {
395 if err := mf.Close(); err != nil {
396 log.Printf("closing message file: %v", err)
397 }
398 }()
399 fi, err := mf.Stat()
400 xcheckf(err, "stat %q", path)
401 p, err := message.EnsurePart(jlog, false, mf, fi.Size())
402 if err != nil {
403 log.Printf("bad sent message %q: %s", path, err)
404 return
405 }
406
407 words, err = f.ParseMessage(p)
408 if err != nil {
409 log.Printf("bad sent message %q: %s", path, err)
410 return
411 }
412 }
413
414 if err := f.Train(context.Background(), msg.ham, words); err != nil {
415 log.Printf("train: %s", err)
416 }
417 }
418
419 for _, m := range msgs {
420 play(m)
421 }
422
423 err := f.Save()
424 xcheckf(err, "saving filter")
425
426 fmt.Printf("completed, nham %d, nsent %d, nspam %d, nbad %d, nwithoutdate %d\n", nham, nsent, nspam, nbad, nnodate)
427 fmt.Printf("total ham, ok %d, bad %d\n", nhamok, nhambad)
428 fmt.Printf("total spam, ok %d, bad %d\n", nspamok, nspambad)
429 fmt.Printf("specifity (true negatives, hams identified): %.6f\n", float64(nhamok)/(float64(nhamok+nhambad)))
430 fmt.Printf("sensitivity (true positives, spams identified): %.6f\n", float64(nspamok)/(float64(nspamok+nspambad)))
431 fmt.Printf("accuracy: %.6f\n", float64(nhamok+nspamok)/float64(nhamok+nhambad+nspamok+nspambad))
432}
433