-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathcache.go
More file actions
454 lines (371 loc) · 11.9 KB
/
cache.go
File metadata and controls
454 lines (371 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
// SPDX-FileCopyrightText: The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package mdns
import (
"math/rand"
"strings"
"sync"
"time"
"golang.org/x/net/dns/dnsmessage"
)
const (
	// rrClassCacheFlush is the top bit of a resource record's class field,
	// which mDNS repurposes as the cache-flush flag (RFC 6762 §10.2).
	rrClassCacheFlush = 1 << 15

	// goodbyeTTL is how long a record announced with TTL=0 is retained
	// before it is dropped (RFC 6762 §10.1).
	goodbyeTTL = 1 * time.Second

	// cacheFlushDelay is the grace period given to stale records after a
	// response carrying the cache-flush bit arrives (RFC 6762 §10.2).
	cacheFlushDelay = 1 * time.Second

	// maxRecordTTL is the upper bound on a cached record's lifetime.
	// RFC 6762 §10 recommends TTLs of 75 minutes or less; clamping guards
	// against hostile or misconfigured responders that advertise
	// near-immortal records.
	maxRecordTTL = 75 * time.Minute

	// maxCacheEntries bounds the number of cached records so a busy or
	// hostile network cannot grow the cache without limit. Once the bound
	// is reached, the entry closest to expiry is evicted to make room.
	maxCacheEntries = 4096
)
type (
	// cacheKey groups records that share a lowercased owner name, an RR
	// type, and an RR class with the cache-flush bit already cleared.
	cacheKey struct {
		name    string
		rrType  dnsmessage.Type
		rrClass dnsmessage.Class
	}

	// cacheEntry is a single cached resource record together with the
	// timing data used for expiry and refresh scheduling.
	cacheEntry struct {
		// resource is the record as stored (class has no cache-flush bit).
		resource dnsmessage.Resource
		// createdAt is when the entry was inserted or last refreshed by a
		// record with matching rdata.
		createdAt time.Time
		// expiresAt is the instant from which the entry counts as expired.
		expiresAt time.Time
		// originalTTL is the lifetime granted when the entry was (re)set;
		// zero means the entry is never offered for refresh (e.g. goodbyes).
		originalTTL time.Duration
		// refreshJitter is the random fraction added to the next refresh
		// threshold.
		refreshJitter float64
		// refreshesSent counts refresh thresholds already handed out.
		refreshesSent uint8
	}

	// cache is an mDNS record cache (RFC 6762 §10) whose state is guarded
	// by mu.
	cache struct {
		mu      sync.RWMutex
		entries map[cacheKey][]cacheEntry
		// size counts stored entries, including expired ones not yet swept.
		size int
		now  func() time.Time
	}
)
// newCache returns an empty cache that reads the current time from now.
// Injecting the clock lets callers (and tests) control expiry decisions.
func newCache(now func() time.Time) *cache {
	c := new(cache)
	c.entries = map[cacheKey][]cacheEntry{}
	c.now = now

	return c
}
// makeCacheKey derives the cache key for a resource header. The owner name
// is lowercased so that lookups match regardless of letter case; type and
// class are copied verbatim (callers clear the cache-flush bit first).
func makeCacheKey(hdr dnsmessage.ResourceHeader) cacheKey {
	var key cacheKey
	key.name = strings.ToLower(hdr.Name.String())
	key.rrType = hdr.Type
	key.rrClass = hdr.Class

	return key
}
// insert stores res, received at receivedAt, in the cache under the write
// lock. The cache-flush bit is stripped from the class before the key is
// built. Then one of three paths applies:
//   - TTL=0: the record is a goodbye (RFC 6762 §10.1);
//   - cache-flush bit set: older records for the key are flushed
//     (RFC 6762 §10.2) before the record is inserted or updated;
//   - otherwise: a plain insert or update.
func (c *cache) insert(res dnsmessage.Resource, receivedAt time.Time) {
	c.mu.Lock()
	defer c.mu.Unlock()

	flush := res.Header.Class&rrClassCacheFlush != 0
	res.Header.Class &^= rrClassCacheFlush
	key := makeCacheKey(res.Header)

	switch {
	case res.Header.TTL == 0:
		c.insertGoodbye(key, res, receivedAt)
	case flush:
		c.applyCacheFlush(key, receivedAt)
		c.insertOrUpdate(key, res, receivedAt)
	default:
		c.insertOrUpdate(key, res, receivedAt)
	}
}
// insertGoodbye records a TTL=0 "goodbye" announcement (RFC 6762 §10.1).
// The caller must hold c.mu for writing.
//
// If an entry with identical rdata is cached, its expiry is moved to
// goodbyeTTL after receivedAt and its originalTTL is zeroed, which keeps it
// out of refresh scheduling. If no such entry exists, the goodbye itself is
// stored with a header TTL of 1 and the same one-second lifetime, so late
// listeners still observe it briefly.
func (c *cache) insertGoodbye(key cacheKey, res dnsmessage.Resource, receivedAt time.Time) {
	deadline := receivedAt.Add(goodbyeTTL)

	bucket := c.entries[key]
	match := -1
	for i := range bucket {
		if resourceDataEqual(bucket[i].resource, res) {
			match = i
			break
		}
	}

	if match >= 0 {
		bucket[match].expiresAt = deadline
		bucket[match].originalTTL = 0
		return
	}

	res.Header.TTL = 1
	c.appendEntry(key, cacheEntry{
		resource:  res,
		createdAt: receivedAt,
		expiresAt: deadline,
	})
}
// applyCacheFlush reacts to a record carrying the cache-flush bit
// (RFC 6762 §10.2). The caller must hold c.mu for writing.
//
// Every entry under key that was created more than cacheFlushDelay before
// receivedAt is scheduled to expire cacheFlushDelay after receivedAt.
// Entries created within the last cacheFlushDelay are left untouched,
// because they may be part of the same multi-packet response.
//
// Two guarantees apply to flushed entries:
//   - Their expiry is only ever shortened. An entry that already expires
//     sooner, or has expired but not yet been swept, is not extended or
//     resurrected.
//   - Their originalTTL is set to 0, as for goodbye entries, so
//     takeRefreshCandidates never offers them for refresh. A flushed record
//     that is re-announced is reset by insertOrUpdate anyway.
func (c *cache) applyCacheFlush(key cacheKey, receivedAt time.Time) {
	entries := c.entries[key]
	staleBefore := receivedAt.Add(-cacheFlushDelay)
	flushAt := receivedAt.Add(cacheFlushDelay)

	for idx := range entries {
		entry := &entries[idx]
		if !entry.createdAt.Before(staleBefore) {
			continue
		}
		if flushAt.Before(entry.expiresAt) {
			entry.expiresAt = flushAt
		}
		// Leaving originalTTL intact would make the refresh fraction
		// (now - (expiresAt - originalTTL)) / originalTTL jump to nearly 1,
		// triggering refresh queries for a record that is being discarded.
		entry.originalTTL = 0
	}
}
// insertOrUpdate caches res under key. The caller must hold c.mu for
// writing.
//
// The effective lifetime is the header TTL clamped to maxRecordTTL. If an
// entry with identical rdata already exists, its header TTL, timestamps and
// refresh schedule are reset in place. Otherwise a new entry is appended,
// which may evict another entry when the cache is full.
func (c *cache) insertOrUpdate(key cacheKey, res dnsmessage.Resource, receivedAt time.Time) {
	lifetime := time.Duration(res.Header.TTL) * time.Second
	if lifetime > maxRecordTTL {
		lifetime = maxRecordTTL
	}
	expiry := receivedAt.Add(lifetime)

	bucket := c.entries[key]
	for i := range bucket {
		existing := &bucket[i]
		if !resourceDataEqual(existing.resource, res) {
			continue
		}
		existing.resource.Header.TTL = res.Header.TTL
		existing.createdAt = receivedAt
		existing.expiresAt = expiry
		existing.originalTTL = lifetime
		existing.refreshJitter = newRefreshJitter()
		existing.refreshesSent = 0
		return
	}

	c.appendEntry(key, cacheEntry{
		resource:      res,
		createdAt:     receivedAt,
		expiresAt:     expiry,
		originalTTL:   lifetime,
		refreshJitter: newRefreshJitter(),
	})
}
// appendEntry stores entry as a new record under key and bumps c.size.
// When c.size has already reached maxCacheEntries, the entry with the
// earliest expiry is evicted first. The caller must hold c.mu for writing.
func (c *cache) appendEntry(key cacheKey, entry cacheEntry) {
	if full := c.size >= maxCacheEntries; full {
		c.evictSoonestExpiring()
	}

	// Re-read the bucket after eviction, which may have modified it.
	bucket := c.entries[key]
	bucket = append(bucket, entry)
	c.entries[key] = bucket
	c.size++
}
// evictSoonestExpiring removes the single cached entry with the earliest
// expiresAt and decrements c.size. Expired-but-unswept entries therefore
// go first. It is a no-op when the cache holds no entries. Ties are broken
// by map iteration order. The caller must hold c.mu for writing.
func (c *cache) evictSoonestExpiring() {
	var (
		victimKey    cacheKey
		victimIdx    = -1
		victimExpiry time.Time
	)
	for key, entries := range c.entries {
		for idx := range entries {
			if victimIdx == -1 || entries[idx].expiresAt.Before(victimExpiry) {
				victimKey = key
				victimIdx = idx
				victimExpiry = entries[idx].expiresAt
			}
		}
	}
	if victimIdx == -1 {
		return
	}

	entries := c.entries[victimKey]
	last := len(entries) - 1
	copy(entries[victimIdx:], entries[victimIdx+1:])
	// Zero the vacated tail slot. Otherwise the shared backing array keeps
	// the evicted record's data (body pointers, TXT slices) reachable until
	// a later append overwrites it.
	entries[last] = cacheEntry{}
	entries = entries[:last]

	if len(entries) == 0 {
		delete(c.entries, victimKey)
	} else {
		c.entries[victimKey] = entries
	}
	c.size--
}
// lookup returns copies of the unexpired records cached under name, rrType
// and rrClass. The name is matched case-insensitively. Each returned
// header's TTL is the whole number of seconds left before expiry, rounded
// down, so it reads 0 during a record's final second. The result is nil
// when nothing matches. lookup takes only the read lock.
func (c *cache) lookup(name string, rrType dnsmessage.Type, rrClass dnsmessage.Class) []dnsmessage.Resource {
	key := cacheKey{name: strings.ToLower(name), rrType: rrType, rrClass: rrClass}

	c.mu.RLock()
	defer c.mu.RUnlock()

	now := c.now()
	bucket := c.entries[key]

	var matches []dnsmessage.Resource
	for i := range bucket {
		expiry := bucket[i].expiresAt
		if !expiry.After(now) {
			continue
		}
		rec := bucket[i].resource
		rec.Header.TTL = uint32(expiry.Sub(now) / time.Second) //nolint:gosec // remaining is positive after expiry check
		matches = append(matches, rec)
	}

	return matches
}
// sweep drops every expired entry under the write lock. Keys left without
// entries are deleted from the map, and c.size is recomputed from the
// survivors. Surviving entries are copied into freshly allocated slices.
func (c *cache) sweep() {
	c.mu.Lock()
	defer c.mu.Unlock()

	current := c.now()
	total := 0

	for k, list := range c.entries {
		var kept []cacheEntry
		for i := range list {
			if current.Before(list[i].expiresAt) {
				kept = append(kept, list[i])
			}
		}

		if len(kept) == 0 {
			delete(c.entries, k)
			continue
		}
		c.entries[k] = kept
		total += len(kept)
	}

	c.size = total
}
// refreshThresholds lists, in ascending order, the fractions of a record's
// TTL at which refresh queries go out (RFC 6762 §5.2): 80%, 85%, 90% and
// 95%. They are written as percentages and divided by 100. IEEE division
// is correctly rounded, so the results equal the decimal literals 0.80 and
// so on exactly.
func refreshThresholds() [4]float64 {
	percents := [4]float64{80, 85, 90, 95}
	for i := range percents {
		percents[i] /= 100
	}

	return percents
}
// maxRefreshJitter is the upper bound (as a fraction of TTL) of the random
// offset added to each refresh threshold.
const maxRefreshJitter = 0.02

// newRefreshJitter draws a fresh random offset in [0, maxRefreshJitter)
// to add to the next refresh threshold (RFC 6762 §5.2). A new value is drawn
// for every threshold, so repeated polling is not biased toward the bare
// threshold.
func newRefreshJitter() float64 {
	jitter := maxRefreshJitter * rand.Float64() //nolint:gosec // weak random is fine for jitter

	return jitter
}
// takeRefreshCandidates returns, in first-seen order, the keys among keys
// that have at least one entry past its next refresh threshold
// (RFC 6762 §5.2). It also consumes that threshold under the write lock.
//
// An entry is due when it is still unexpired and the elapsed fraction of
// its originalTTL has reached refreshThresholds()[refreshesSent] plus its
// refreshJitter. Each due entry has refreshesSent incremented and a new
// jitter drawn, so a threshold is handed out at most once. The caller is
// expected to send the refresh query.
//
// Additional rules:
//   - Duplicate keys in the input are examined only once.
//   - A key appears at most once in the result, since one question
//     refreshes all of its records.
//   - Entries with originalTTL <= 0 (goodbyes) are skipped.
//   - An entry gets at most len(refreshThresholds()) refreshes per TTL.
func (c *cache) takeRefreshCandidates(keys []cacheKey) []cacheKey {
	c.mu.Lock()
	defer c.mu.Unlock()

	current := c.now()
	levels := refreshThresholds()
	visited := make(map[cacheKey]struct{}, len(keys))
	var due []cacheKey

	for _, k := range keys {
		if _, already := visited[k]; already {
			continue
		}
		visited[k] = struct{}{}

		queued := false
		list := c.entries[k]
		for i := range list {
			e := &list[i]

			// Goodbyes and entries whose refresh budget is spent never qualify.
			if e.originalTTL <= 0 || int(e.refreshesSent) >= len(levels) {
				continue
			}
			// Expired entries are left for sweep.
			if !current.Before(e.expiresAt) {
				continue
			}

			elapsed := current.Sub(e.expiresAt.Add(-e.originalTTL))
			progress := float64(elapsed) / float64(e.originalTTL)
			if progress < levels[e.refreshesSent]+e.refreshJitter {
				continue
			}

			if !queued {
				due = append(due, k)
				queued = true
			}
			e.refreshesSent++
			e.refreshJitter = newRefreshJitter()
		}
	}

	return due
}
// flushAll discards every cached entry, e.g. after a network topology change
// (RFC 6762 §10.3). A new empty map replaces the old one, so the old map's
// memory is released rather than kept for reuse.
func (c *cache) flushAll() {
	c.mu.Lock()
	c.entries = make(map[cacheKey][]cacheEntry)
	c.size = 0
	c.mu.Unlock()
}
// reduceTTLs makes every entry expire no later than maxRemaining from now.
// Entries already due sooner keep their expiry. Only expiresAt is touched;
// originalTTL and the refresh counters are unchanged.
func (c *cache) reduceTTLs(maxRemaining time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()

	limit := c.now().Add(maxRemaining)
	for key := range c.entries {
		bucket := c.entries[key]
		for i := range bucket {
			if bucket[i].expiresAt.After(limit) {
				bucket[i].expiresAt = limit
			}
		}
	}
}
// len counts the entries that have not yet expired, under the read lock.
// Unlike c.size, it excludes expired entries that sweep has not removed yet.
func (c *cache) len() int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	now := c.now()
	live := 0
	for _, bucket := range c.entries {
		for i := range bucket {
			if bucket[i].expiresAt.After(now) {
				live++
			}
		}
	}

	return live
}
// resourceDataEqual reports whether resA and resB describe the same record.
// They must have the same owner name (case-insensitive), the same RR type,
// and equal bodies per resourceBodyEqual. Class and TTL are not compared.
func resourceDataEqual(resA, resB dnsmessage.Resource) bool {
	hdrA, hdrB := resA.Header, resB.Header
	sameOwner := strings.EqualFold(hdrA.Name.String(), hdrB.Name.String())

	return sameOwner && hdrA.Type == hdrB.Type && resourceBodyEqual(resA.Body, resB.Body)
}
// resourceBodyEqual reports whether two resource bodies carry the same
// rdata, comparing type-specific fields:
//   - Domain names inside the rdata (PTR, SRV target, CNAME, NS, MX
//     exchange) are compared case-insensitively.
//   - TXT strings and opaque rdata are compared byte-for-byte.
//   - Bodies of different concrete types never compare equal, nor do body
//     types not listed below.
//
// *dnsmessage.UnknownResource is how dnsmessage represents record types it
// has no dedicated struct for (NSEC, which mDNS uses for negative
// responses, is one). Without an explicit case such records would never
// equal themselves. Every re-announcement would then append a duplicate
// cache entry instead of refreshing the existing one.
//
//nolint:cyclop
func resourceBodyEqual(bodyA, bodyB dnsmessage.ResourceBody) bool {
	switch valA := bodyA.(type) {
	case *dnsmessage.AResource:
		valB, ok := bodyB.(*dnsmessage.AResource)
		return ok && valA.A == valB.A
	case *dnsmessage.AAAAResource:
		valB, ok := bodyB.(*dnsmessage.AAAAResource)
		return ok && valA.AAAA == valB.AAAA
	case *dnsmessage.PTRResource:
		valB, ok := bodyB.(*dnsmessage.PTRResource)
		return ok && strings.EqualFold(valA.PTR.String(), valB.PTR.String())
	case *dnsmessage.SRVResource:
		valB, ok := bodyB.(*dnsmessage.SRVResource)
		return ok && srvFieldsEqual(valA, valB)
	case *dnsmessage.TXTResource:
		valB, ok := bodyB.(*dnsmessage.TXTResource)
		return ok && txtSlicesEqual(valA.TXT, valB.TXT)
	case *dnsmessage.CNAMEResource:
		valB, ok := bodyB.(*dnsmessage.CNAMEResource)
		return ok && strings.EqualFold(valA.CNAME.String(), valB.CNAME.String())
	case *dnsmessage.NSResource:
		valB, ok := bodyB.(*dnsmessage.NSResource)
		return ok && strings.EqualFold(valA.NS.String(), valB.NS.String())
	case *dnsmessage.MXResource:
		valB, ok := bodyB.(*dnsmessage.MXResource)
		return ok && valA.Pref == valB.Pref &&
			strings.EqualFold(valA.MX.String(), valB.MX.String())
	case *dnsmessage.UnknownResource:
		valB, ok := bodyB.(*dnsmessage.UnknownResource)
		// string conversions used directly in a comparison do not allocate.
		return ok && valA.Type == valB.Type && string(valA.Data) == string(valB.Data)
	default:
		return false
	}
}
// srvFieldsEqual reports whether two SRV bodies match. Priority, weight and
// port must be identical, and the target host must match ignoring case. The
// cheap numeric fields are checked before the name comparison.
func srvFieldsEqual(resA, resB *dnsmessage.SRVResource) bool {
	if resA.Priority != resB.Priority || resA.Weight != resB.Weight || resA.Port != resB.Port {
		return false
	}

	return strings.EqualFold(resA.Target.String(), resB.Target.String())
}
// txtSlicesEqual reports whether two TXT string lists have the same length
// and byte-identical strings at every position. A nil slice equals an empty
// one.
func txtSlicesEqual(sliceA, sliceB []string) bool {
	same := len(sliceA) == len(sliceB)
	for i := 0; same && i < len(sliceA); i++ {
		same = sliceA[i] == sliceB[i]
	}

	return same
}