|
5 | 5 | "fmt"
|
6 | 6 | "net/netip"
|
7 | 7 | "strconv"
|
| 8 | + "strings" |
8 | 9 | "time"
|
9 | 10 |
|
10 | 11 | "github.com/qdm12/ddns-updater/internal/constants"
|
@@ -51,39 +52,53 @@ func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
|
51 | 52 | }
|
52 | 53 | }
|
53 | 54 |
|
54 |
| -func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries int) ( |
| 55 | +func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries uint) ( |
55 | 56 | ipv4, ipv6 []netip.Addr, err error,
|
56 | 57 | ) {
|
57 |
| - for range tries { |
58 |
| - ipv4, ipv6, err = s.lookupIPs(ctx, hostname) |
59 |
| - if err == nil { |
60 |
| - return ipv4, ipv6, nil |
61 |
| - } |
| 58 | + type result struct { |
| 59 | + network string |
| 60 | + ips []netip.Addr |
| 61 | + err error |
62 | 62 | }
|
63 |
| - return nil, nil, err |
64 |
| -} |
65 |
| - |
66 |
| -func (s *Service) lookupIPs(ctx context.Context, hostname string) ( |
67 |
| - ipv4, ipv6 []netip.Addr, err error, |
68 |
| -) { |
69 |
| - ips, err := s.resolver.LookupNetIP(ctx, "ip", hostname) |
70 |
| - if err != nil { |
71 |
| - return nil, nil, err |
| 63 | + results := make(chan result) |
| 64 | + networks := []string{"ip4", "ip6"} |
| 65 | + lookupCtx, cancel := context.WithCancel(ctx) |
| 66 | + for _, network := range networks { |
| 67 | + go func(ctx context.Context, network string, results chan<- result) { |
| 68 | + for range tries { |
| 69 | + ips, err := s.resolver.LookupNetIP(ctx, network, hostname) |
| 70 | + if err != nil { |
| 71 | + if strings.HasSuffix(err.Error(), "no such host") { |
| 72 | + results <- result{network: network} // no IP address for this network |
| 73 | + return |
| 74 | + } |
| 75 | + continue // retry |
| 76 | + } |
| 77 | + results <- result{network: network, ips: ips, err: err} |
| 78 | + return |
| 79 | + } |
| 80 | + }(lookupCtx, network, results) |
72 | 81 | }
|
73 | 82 |
|
74 |
| - ipv4 = make([]netip.Addr, 0, len(ips)) |
75 |
| - ipv6 = make([]netip.Addr, 0, len(ips)) |
76 |
| - for _, ip := range ips { |
77 |
| - switch { |
78 |
| - case !ip.IsValid(): |
79 |
| - case ip.Is4(): |
80 |
| - ipv4 = append(ipv4, ip) |
81 |
| - default: // IPv6 |
82 |
| - ipv6 = append(ipv6, ip) |
| 83 | + for range networks { |
| 84 | + result := <-results |
| 85 | + if result.err != nil { |
| 86 | + if err == nil { |
| 87 | + cancel() |
| 88 | + err = fmt.Errorf("looking up %s addresses: %w", result.network, result.err) |
| 89 | + } |
| 90 | + continue |
| 91 | + } |
| 92 | + switch result.network { |
| 93 | + case "ip4": |
| 94 | + ipv4 = result.ips |
| 95 | + case "ip6": |
| 96 | + ipv6 = result.ips |
83 | 97 | }
|
84 | 98 | }
|
| 99 | + cancel() |
85 | 100 |
|
86 |
| - return ipv4, ipv6, nil |
| 101 | + return ipv4, ipv6, err |
87 | 102 | }
|
88 | 103 |
|
89 | 104 | func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
|
|
0 commit comments