Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pkg/ffdns/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright © 2026 Kaleido, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ffdns

import (
"time"

"github.com/hyperledger/firefly-common/pkg/config"
)

const (
// Servers an optional list of DNS server addresses (host or host:port, port defaults
// to 53). Setting this forces use of Go's built-in resolver.
DNSServers = "servers"
// Timeout the dial timeout when contacting a configured DNS server
DNSTimeout = "timeout"
)

type Config struct {
Servers []string
Timeout time.Duration
}

func InitConfig(conf config.Section) {
conf.AddKnownKey(DNSServers)
conf.AddKnownKey(DNSTimeout)
}

func GenerateConfig(conf config.Section) (*Config, error) {
return &Config{
Servers: conf.GetStringSlice(DNSServers),
Timeout: conf.GetDuration(DNSTimeout),
}, nil
}
129 changes: 129 additions & 0 deletions pkg/ffdns/ffdns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright © 2026 Kaleido, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ffdns

import (
"context"
"errors"
"net"

"github.com/hyperledger/firefly-common/pkg/config"
"github.com/hyperledger/firefly-common/pkg/metric"
)

const (
metricsDNSRequestsTotal = "dns_requests_total"
metricsDNSResponsesTotal = "dns_responses_total"
metricsDNSErrorsTotal = "dns_errors_total"
)

var metricsManager metric.MetricsManager

func EnableResolverMetrics(ctx context.Context, metricsRegistry metric.MetricsRegistry) {
if metricsManager != nil {
return
}
metricsManager, _ = metricsRegistry.NewMetricsManagerForSubsystem(ctx, "dns")
metricsManager.NewCounterMetricWithLabels(ctx, metricsDNSRequestsTotal, "DNS requests", []string{"server"}, false)
metricsManager.NewCounterMetricWithLabels(ctx, metricsDNSResponsesTotal, "DNS responses", []string{"server", "status"}, false)
metricsManager.NewCounterMetricWithLabels(ctx, metricsDNSErrorsTotal, "DNS errors", []string{"server", "error"}, false)
}

// NewDNSResolver builds a pure-Go *net.Resolver for metrics instructmentation, custom timeouts, and/or custom servers.
// The resolver will dial the given DNS servers (each host or host:port, port defaulting to 53) in order, failing over to the
// next on error. Returns nil if none of the customizations (metrics, timeout, or servers) are enabeld.
// Exported so non-ffresty dialers — e.g. a WebSocket dialer — can honour the same
// DNS config as the HTTP client.
func NewResolver(config config.Section) *net.Resolver {
cfg, err := GenerateConfig(config)
if err != nil {
return nil
}

return NewResolverWithConfig(cfg)
}

func NewResolverWithConfig(cfg *Config) *net.Resolver {
var servers []string
if len(cfg.Servers) > 0 {
servers = make([]string, len(cfg.Servers))
for i, server := range cfg.Servers {
servers[i] = withDefaultDNSPort(server)
}
}

// If we have nothing to layer on top of the system resolver — no configured servers, no
// dial timeout, and metrics disabled — leave it untouched (callers treat nil as "use the
// system resolver"). Returning a resolver here would force Go's built-in resolver
// (PreferGo) in deployments that haven't opted into any of these.
if len(servers) == 0 && cfg.Timeout <= 0 && metricsManager == nil {
return nil
}

return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{Timeout: cfg.Timeout}
// When no servers are explicitly configured, wrap Go's built-in resolver: it has
// already selected a nameserver from the system config (resolv.conf) and passes it
// as address, so we dial that and still apply our timeout and metrics.
dialServers := servers
if len(dialServers) == 0 {
dialServers = []string{address}
}
var err error
// Go's built-in resolver dials a fresh connection per query exchange (escalating
// from UDP to TCP for truncated responses), so each Dial maps to a DNS request. We
// record metrics at this connection level.
for _, server := range dialServers {
recordDNSMetric(ctx, metricsDNSRequestsTotal, map[string]string{"server": server})
var conn net.Conn
if conn, err = d.DialContext(ctx, network, server); err == nil {
recordDNSMetric(ctx, metricsDNSResponsesTotal, map[string]string{"server": server, "status": "success"})
return conn, nil
}
recordDNSMetric(ctx, metricsDNSErrorsTotal, map[string]string{"server": server, "error": classifyDNSError(err)})
}
return nil, err
},
}
}

// recordDNSMetric increments a DNS counter when resolver metrics have been enabled, and is a no-op otherwise.
func recordDNSMetric(ctx context.Context, name string, labels map[string]string) {
if metricsManager == nil {
return
}
metricsManager.IncCounterMetricWithLabels(ctx, name, labels, nil)
}

// classifyDNSError maps a dial error to a low-cardinality label so the dns_errors_total metric doesn't explode.
func classifyDNSError(err error) string {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return "timeout"
}
return "error"
}

// withDefaultDNSPort ensures a DNS server address has a port, defaulting to 53.
func withDefaultDNSPort(server string) string {
if _, _, err := net.SplitHostPort(server); err == nil {
return server
}
return net.JoinHostPort(server, "53")
}
200 changes: 200 additions & 0 deletions pkg/ffdns/ffdns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// Copyright © 2026 Kaleido, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ffdns

import (
"context"
"net"
"strings"
"testing"
"time"

"github.com/hyperledger/firefly-common/pkg/config"
"github.com/hyperledger/firefly-common/pkg/metric"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// counterTotal sums the values of all series of a counter whose metric family name ends with
// the given suffix (the registry prefixes names with component + subsystem).
func counterTotal(t *testing.T, mr metric.MetricsRegistry, nameSuffix string) float64 {
families, err := mr.GetGatherer().Gather()
require.NoError(t, err)
var total float64
for _, mf := range families {
if strings.HasSuffix(mf.GetName(), nameSuffix) {
for _, m := range mf.GetMetric() {
if c := m.GetCounter(); c != nil {
total += c.GetValue()
}
}
}
}
return total
}

var utConf = config.RootSection("dns_unit_tests")

func resetConf() {
config.RootConfigReset()
InitConfig(utConf)
}

func TestWithDefaultDNSPort(t *testing.T) {
assert.Equal(t, "8.8.8.8:53", withDefaultDNSPort("8.8.8.8"))
assert.Equal(t, "8.8.8.8:5353", withDefaultDNSPort("8.8.8.8:5353"))
assert.Equal(t, "[2001:db8::1]:53", withDefaultDNSPort("2001:db8::1"))
assert.Equal(t, "[2001:db8::1]:5353", withDefaultDNSPort("[2001:db8::1]:5353"))
}

func TestNewResolverWithConfig(t *testing.T) {
// No servers -> nil, leaving Go's default system resolver selection in place
assert.Nil(t, NewResolverWithConfig(&Config{}))

// Servers configured -> pure-Go resolver
r := NewResolverWithConfig(&Config{Servers: []string{"8.8.8.8"}})
require.NotNil(t, r)
assert.True(t, r.PreferGo)
assert.NotNil(t, r.Dial)
}

func TestNewResolverFromConfigSection(t *testing.T) {
resetConf()
utConf.Set(DNSServers, []string{"8.8.8.8", "1.1.1.1:53"})
r := NewResolver(utConf)
require.NotNil(t, r)
assert.True(t, r.PreferGo)

resetConf()
assert.Nil(t, NewResolver(utConf))
}

func TestResolverDialFailover(t *testing.T) {
// Stand up a listener acting as the "good" DNS server
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()

accepted := make(chan struct{}, 1)
go func() {
conn, acceptErr := ln.Accept()
if acceptErr == nil {
accepted <- struct{}{}
_ = conn.Close()
}
}()

// First server is unroutable so the dialer must fail over to the live listener
r := NewResolverWithConfig(&Config{
Timeout: 5 * time.Second,
Servers: []string{"127.0.0.1:1", ln.Addr().String()},
})
require.NotNil(t, r)

conn, err := r.Dial(context.Background(), "tcp", "ignored:53")
require.NoError(t, err)
defer conn.Close()
assert.Equal(t, ln.Addr().String(), conn.RemoteAddr().String())

select {
case <-accepted:
case <-time.After(5 * time.Second):
t.Fatal("DNS dial did not reach the configured server")
}
}

func TestResolverDialAllFail(t *testing.T) {
r := NewResolverWithConfig(&Config{
Timeout: 250 * time.Millisecond,
Servers: []string{"127.0.0.1:1"},
})
require.NotNil(t, r)
_, err := r.Dial(context.Background(), "tcp", "ignored:53")
assert.Error(t, err)
}

func TestEnableResolverMetrics(t *testing.T) {
metricsManager = nil
defer func() { metricsManager = nil }()

ctx := context.Background()
mr := metric.NewPrometheusMetricsRegistry("test")
EnableResolverMetrics(ctx, mr)
require.NotNil(t, metricsManager)

// Idempotent - a second call is a no-op rather than re-registering
EnableResolverMetrics(ctx, mr)
}

func TestResolverDialRecordsMetrics(t *testing.T) {
metricsManager = nil
defer func() { metricsManager = nil }()

ctx := context.Background()
mr := metric.NewPrometheusMetricsRegistry("test")
EnableResolverMetrics(ctx, mr)

// Live listener acts as the second (good) DNS server; the first is unroutable so a single
// Dial exercises the request, error (failover), and response metric paths together.
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() {
if conn, acceptErr := ln.Accept(); acceptErr == nil {
_ = conn.Close()
}
}()

r := NewResolverWithConfig(&Config{
Timeout: 5 * time.Second,
Servers: []string{"127.0.0.1:1", ln.Addr().String()},
})
require.NotNil(t, r)
conn, err := r.Dial(ctx, "tcp", "ignored:53")
require.NoError(t, err)
defer conn.Close()

assert.GreaterOrEqual(t, counterTotal(t, mr, "dns_requests_total"), float64(2), "one request per server attempted")
assert.GreaterOrEqual(t, counterTotal(t, mr, "dns_responses_total"), float64(1), "one successful response")
assert.GreaterOrEqual(t, counterTotal(t, mr, "dns_errors_total"), float64(1), "first server failed over")
}

func TestResolverDialNoMetricsWhenDisabled(t *testing.T) {
metricsManager = nil // metrics not enabled -> recording is a no-op, no panic
r := NewResolverWithConfig(&Config{
Timeout: 250 * time.Millisecond,
Servers: []string{"127.0.0.1:1"},
})
require.NotNil(t, r)
_, err := r.Dial(context.Background(), "tcp", "ignored:53")
assert.Error(t, err)
}

func TestClassifyDNSError(t *testing.T) {
assert.Equal(t, "error", classifyDNSError(assertAnErr{}))
assert.Equal(t, "timeout", classifyDNSError(timeoutErr{}))
}

type assertAnErr struct{}

func (assertAnErr) Error() string { return "boom" }

type timeoutErr struct{}

func (timeoutErr) Error() string { return "i/o timeout" }
func (timeoutErr) Timeout() bool { return true }
func (timeoutErr) Temporary() bool { return true }
Loading
Loading