summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPavel Zbitskiy <65323360+algorandskiy@users.noreply.github.com>2024-02-13 16:43:32 -0500
committerGitHub <noreply@github.com>2024-02-13 16:43:32 -0500
commitd8c825d96b90b4f1c84a5b3771987ab53bfc8568 (patch)
tree12964e83d5d58182fb93dcca0b1de9b864dbace6
parentcaec33dfe083327994572d77b66a8a0c6d40bc2c (diff)
network: use network context for DNS operations in readFromSRV (#5936)
-rw-r--r--cmd/catchpointdump/net.go2
-rw-r--r--network/wsNetwork.go6
-rw-r--r--network/wsNetwork_test.go2
-rw-r--r--tools/network/bootstrap.go14
-rw-r--r--tools/network/bootstrap_test.go5
-rw-r--r--tools/network/telemetryURIUpdateService.go3
6 files changed, 17 insertions, 15 deletions
diff --git a/cmd/catchpointdump/net.go b/cmd/catchpointdump/net.go
index 2de40c3c2..41e1fd1dd 100644
--- a/cmd/catchpointdump/net.go
+++ b/cmd/catchpointdump/net.go
@@ -78,7 +78,7 @@ var netCmd = &cobra.Command{
if relayAddress != "" {
addrs = []string{relayAddress}
} else {
- addrs, err = tools.ReadFromSRV("algobootstrap", "tcp", networkName, "", false)
+ addrs, err = tools.ReadFromSRV(context.Background(), "algobootstrap", "tcp", networkName, "", false)
if err != nil || len(addrs) == 0 {
reportErrorf("Unable to bootstrap records for '%s' : %v", networkName, err)
}
diff --git a/network/wsNetwork.go b/network/wsNetwork.go
index 5bfabbaf2..92a02976b 100644
--- a/network/wsNetwork.go
+++ b/network/wsNetwork.go
@@ -289,7 +289,7 @@ type WebsocketNetwork struct {
protocolVersion string
// resolveSRVRecords is a function that resolves SRV records for a given service, protocol and name
- resolveSRVRecords func(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error)
+ resolveSRVRecords func(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error)
}
const (
@@ -1887,7 +1887,7 @@ func (wn *WebsocketNetwork) mergePrimarySecondaryRelayAddressSlices(network prot
func (wn *WebsocketNetwork) getDNSAddrs(dnsBootstrap string) (relaysAddresses []string, archiverAddresses []string) {
var err error
- relaysAddresses, err = wn.resolveSRVRecords("algobootstrap", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
+ relaysAddresses, err = wn.resolveSRVRecords(wn.ctx, "algobootstrap", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
if err != nil {
// only log this warning on testnet or devnet
if wn.NetworkID == config.Devnet || wn.NetworkID == config.Testnet {
@@ -1896,7 +1896,7 @@ func (wn *WebsocketNetwork) getDNSAddrs(dnsBootstrap string) (relaysAddresses []
relaysAddresses = nil
}
- archiverAddresses, err = wn.resolveSRVRecords("archive", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
+ archiverAddresses, err = wn.resolveSRVRecords(wn.ctx, "archive", "tcp", dnsBootstrap, wn.config.FallbackDNSResolverAddress, wn.config.DNSSecuritySRVEnforced())
if err != nil {
// only log this warning on testnet or devnet
if wn.NetworkID == config.Devnet || wn.NetworkID == config.Testnet {
diff --git a/network/wsNetwork_test.go b/network/wsNetwork_test.go
index e35cc7d17..ef5769bcf 100644
--- a/network/wsNetwork_test.go
+++ b/network/wsNetwork_test.go
@@ -4156,7 +4156,7 @@ func TestRefreshRelayArchivePhonebookAddresses(t *testing.T) {
}
// Mock the SRV record lookup
- netA.resolveSRVRecords = func(service string, protocol string, name string, fallbackDNSResolverAddress string,
+ netA.resolveSRVRecords = func(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string,
secure bool) (addrs []string, err error) {
if service == "algobootstrap" && protocol == "tcp" && name == primarySRVBootstrap {
return primaryRelayResolvedRecords, nil
diff --git a/tools/network/bootstrap.go b/tools/network/bootstrap.go
index f04c67528..d30ae4bda 100644
--- a/tools/network/bootstrap.go
+++ b/tools/network/bootstrap.go
@@ -24,7 +24,7 @@ import (
"github.com/algorand/go-algorand/logging"
)
-func readFromSRV(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (records []*net.SRV, err error) {
+func readFromSRV(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (records []*net.SRV, err error) {
log := logging.Base()
if name == "" {
log.Debug("no dns lookup due to empty name")
@@ -38,14 +38,14 @@ func readFromSRV(service string, protocol string, name string, fallbackDNSResolv
controller := NewResolveController(secure, fallbackDNSResolverAddress, log)
systemResolver := controller.SystemResolver()
- _, records, sysLookupErr := systemResolver.LookupSRV(context.Background(), service, protocol, name)
+ _, records, sysLookupErr := systemResolver.LookupSRV(ctx, service, protocol, name)
if sysLookupErr != nil {
log.Infof("ReadFromBootstrap: DNS LookupSRV failed when using system resolver: %v", sysLookupErr)
var fallbackLookupErr error
if fallbackDNSResolverAddress != "" {
fallbackResolver := controller.FallbackResolver()
- _, records, fallbackLookupErr = fallbackResolver.LookupSRV(context.Background(), service, protocol, name)
+ _, records, fallbackLookupErr = fallbackResolver.LookupSRV(ctx, service, protocol, name)
}
if fallbackLookupErr != nil {
log.Infof("ReadFromBootstrap: DNS LookupSRV failed when using fallback '%s' resolver: %v", fallbackDNSResolverAddress, fallbackLookupErr)
@@ -54,7 +54,7 @@ func readFromSRV(service string, protocol string, name string, fallbackDNSResolv
if fallbackLookupErr != nil || fallbackDNSResolverAddress == "" {
fallbackResolver := controller.DefaultResolver()
var defaultLookupErr error
- _, records, defaultLookupErr = fallbackResolver.LookupSRV(context.Background(), service, protocol, name)
+ _, records, defaultLookupErr = fallbackResolver.LookupSRV(ctx, service, protocol, name)
if defaultLookupErr != nil {
err = fmt.Errorf("ReadFromBootstrap: DNS LookupSRV failed when using system resolver(%v), fallback resolver(%v), as well as using default resolver due to %v", sysLookupErr, fallbackLookupErr, defaultLookupErr)
return
@@ -65,8 +65,8 @@ func readFromSRV(service string, protocol string, name string, fallbackDNSResolv
}
// ReadFromSRV is a helper to collect SRV addresses for a given name
-func ReadFromSRV(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error) {
- records, err := readFromSRV(service, protocol, name, fallbackDNSResolverAddress, secure)
+func ReadFromSRV(ctx context.Context, service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (addrs []string, err error) {
+ records, err := readFromSRV(ctx, service, protocol, name, fallbackDNSResolverAddress, secure)
if err != nil {
return addrs, err
}
@@ -88,7 +88,7 @@ func ReadFromSRV(service string, protocol string, name string, fallbackDNSResolv
// ReadFromSRVPriority is a helper to collect SRV addresses with priorities for a given name
func ReadFromSRVPriority(service string, protocol string, name string, fallbackDNSResolverAddress string, secure bool) (prioAddrs map[uint16][]string, err error) {
- records, err := readFromSRV(service, protocol, name, fallbackDNSResolverAddress, secure)
+ records, err := readFromSRV(context.Background(), service, protocol, name, fallbackDNSResolverAddress, secure)
if err != nil {
return prioAddrs, err
}
diff --git a/tools/network/bootstrap_test.go b/tools/network/bootstrap_test.go
index 155615c70..a24bea422 100644
--- a/tools/network/bootstrap_test.go
+++ b/tools/network/bootstrap_test.go
@@ -17,6 +17,7 @@
package network
import (
+ "context"
"testing"
"github.com/algorand/go-algorand/test/partitiontest"
@@ -55,10 +56,10 @@ func TestReadFromSRV(t *testing.T) {
fallback := ""
secure := true
- addrs, err := ReadFromSRV("", protocol, name, fallback, secure)
+ addrs, err := ReadFromSRV(context.Background(), "", protocol, name, fallback, secure)
require.Error(t, err)
- addrs, err = ReadFromSRV(service, protocol, name, fallback, secure)
+ addrs, err = ReadFromSRV(context.Background(), service, protocol, name, fallback, secure)
require.NoError(t, err)
require.GreaterOrEqual(t, len(addrs), 1)
addr := addrs[0]
diff --git a/tools/network/telemetryURIUpdateService.go b/tools/network/telemetryURIUpdateService.go
index 66dd87dd0..2b4e61426 100644
--- a/tools/network/telemetryURIUpdateService.go
+++ b/tools/network/telemetryURIUpdateService.go
@@ -17,6 +17,7 @@
package network
import (
+ "context"
"net/url"
"strings"
"time"
@@ -132,5 +133,5 @@ func (t *telemetryURIUpdater) lookupTelemetryURL() (url *url.URL) {
}
func (t *telemetryURIUpdater) readFromSRV(protocol string, bootstrapID string) (addrs []string, err error) {
- return ReadFromSRV("telemetry", protocol, bootstrapID, t.cfg.FallbackDNSResolverAddress, t.cfg.DNSSecuritySRVEnforced())
+ return ReadFromSRV(context.Background(), "telemetry", protocol, bootstrapID, t.cfg.FallbackDNSResolverAddress, t.cfg.DNSSecuritySRVEnforced())
}