From a78fd69f8080a1dc0aeb5dc0956267225da9156e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 23 Nov 2022 13:39:42 +0100 Subject: [PATCH] Feature/dns client configuration (#563) Added host configurators for Linux, Windows, and macOS. The host configurator will update the peer system configuration directing DNS queries according to its capabilities. Some Linux distributions don't support split (match) DNS or custom ports, and that will be reported to our management system in another PR --- client/cmd/service.go | 1 + client/cmd/service_installer.go | 10 +- client/internal/dns/dbus_linux.go | 41 +++ client/internal/dns/file_linux.go | 154 ++++++++++ client/internal/dns/host.go | 79 +++++ client/internal/dns/host_darwin.go | 259 ++++++++++++++++ client/internal/dns/host_linux.go | 75 +++++ client/internal/dns/host_windows.go | 260 ++++++++++++++++ client/internal/dns/local.go | 30 +- client/internal/dns/local_test.go | 2 +- client/internal/dns/network_manager_linux.go | 295 +++++++++++++++++++ client/internal/dns/server.go | 76 ++++- client/internal/dns/server_test.go | 51 +++- client/internal/dns/systemd_linux.go | 185 ++++++++++++ client/internal/dns/upstream.go | 2 +- client/internal/engine.go | 9 +- client/internal/engine_test.go | 5 + dns/dns.go | 25 ++ go.mod | 2 +- go.sum | 3 +- iface/iface.go | 6 +- iface/iface_unix.go | 5 + iface/iface_windows.go | 14 + 23 files changed, 1552 insertions(+), 37 deletions(-) create mode 100644 client/internal/dns/dbus_linux.go create mode 100644 client/internal/dns/file_linux.go create mode 100644 client/internal/dns/host.go create mode 100644 client/internal/dns/host_darwin.go create mode 100644 client/internal/dns/host_linux.go create mode 100644 client/internal/dns/host_windows.go create mode 100644 client/internal/dns/network_manager_linux.go create mode 100644 client/internal/dns/systemd_linux.go diff --git a/client/cmd/service.go b/client/cmd/service.go index 7a6729850..18fe5d621 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -32,6 +32,7 @@ func newSVCConfig() *service.Config { Name: name, DisplayName: "Netbird", Description: "A WireGuard-based mesh network that connects your devices into a single private network.", + Option: make(service.KeyValue), } } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 86439ad17..8efb5ee60 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "path/filepath" "runtime" "github.com/spf13/cobra" @@ -32,8 +33,13 @@ var installCmd = &cobra.Command{ } if managementURL != "" { - svcConfig.Arguments = append(svcConfig.Arguments, "--management-url") - svcConfig.Arguments = append(svcConfig.Arguments, managementURL) + svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) + } + + if logFile != "console" { + svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) + svcConfig.Option["LogOutput"] = true + svcConfig.Option["LogDirectory"] = filepath.Dir(logFile) } if runtime.GOOS == "linux" { diff --git a/client/internal/dns/dbus_linux.go b/client/internal/dns/dbus_linux.go new file mode 100644 index 000000000..0f6d4156a --- /dev/null +++ b/client/internal/dns/dbus_linux.go @@ -0,0 +1,41 @@ +package dns + +import ( + "context" + "github.com/godbus/dbus/v5" + log "github.com/sirupsen/logrus" + "time" +) + +const dbusDefaultFlag = 0 + +func isDbusListenerRunning(dest string, path dbus.ObjectPath) bool { + obj, closeConn, err := getDbusObject(dest, path) + if err != nil { + return false + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0).Store() + return err == nil +} + +func getDbusObject(dest string, path dbus.ObjectPath) (dbus.BusObject, func(), error) { + conn, err := dbus.SystemBus() + if err != nil { + return nil, nil, err + } + obj := conn.Object(dest, path) + + closeFunc := func() { + closeErr := conn.Close() + if closeErr != nil { + log.Warnf("got an error closing dbus connection, err: %s", closeErr) + } + } + + return obj, closeFunc, nil +} diff --git a/client/internal/dns/file_linux.go b/client/internal/dns/file_linux.go new file mode 100644 index 000000000..ad6e3a37f --- /dev/null +++ b/client/internal/dns/file_linux.go @@ -0,0 +1,154 @@ +package dns + +import ( + "bytes" + "fmt" + log "github.com/sirupsen/logrus" + "os" +) + +const ( + fileGeneratedResolvConfContentHeader = "# Generated by NetBird" + fileGeneratedResolvConfSearchBeginContent = "search " + fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader + + "\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" + + fileGeneratedResolvConfSearchBeginContent + "%s\n" +) +const ( + fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird" + fileMaxLineCharsLimit = 256 + fileMaxNumberOfSearchDomains = 6 +) + +var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent) + +type fileConfigurator struct { + originalPerms os.FileMode +} + +func newFileConfigurator() (hostManager, error) { + return &fileConfigurator{}, nil +} + +func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error { + backupFileExist := false + _, err := os.Stat(fileDefaultResolvConfBackupLocation) + if err == nil { + backupFileExist = true + } + + if !config.routeAll { + if backupFileExist { + err = f.restore() + if err != nil { + return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group. Restoring the original file return err: %s", err) + } + } + return fmt.Errorf("unable to configure DNS for this peer using file manager without a Primary nameserver group") + } + managerType, err := getOSDNSManagerType() + if err != nil { + return err + } + switch managerType { + case fileManager, netbirdManager: + if !backupFileExist { + err = f.backup() + if err != nil { + return fmt.Errorf("unable to backup the resolv.conf file") + } + } + default: + // todo improve this and maybe restart DNS manager from scratch + return fmt.Errorf("something happened and file manager is not your prefered host dns configurator, restart the agent") + } + + var searchDomains string + appendedDomains := 0 + for _, dConf := range config.domains { + if dConf.matchOnly { + continue + } + if appendedDomains >= fileMaxNumberOfSearchDomains { + // lets log all skipped domains + log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain) + continue + } + if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit { + // lets log all skipped domains + log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain) + continue + } + + searchDomains += " " + dConf.domain + appendedDomains++ + } + content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains) + err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms) + if err != nil { + err = f.restore() + if err != nil { + log.Errorf("attempt to restore default file failed with error: %s", err) + } + return err + } + log.Infof("created a NetBird managed %s file with your DNS settings", defaultResolvConfPath) + return nil +} + +func (f *fileConfigurator) restoreHostDNS() error { + return f.restore() +} + +func (f *fileConfigurator) backup() error { + stats, err := os.Stat(defaultResolvConfPath) + if err != nil { + return fmt.Errorf("got an error while checking stats for %s file. Error: %s", defaultResolvConfPath, err) + } + + f.originalPerms = stats.Mode() + + err = copyFile(defaultResolvConfPath, fileDefaultResolvConfBackupLocation) + if err != nil { + return fmt.Errorf("got error while backing up the %s file. Error: %s", defaultResolvConfPath, err) + } + return nil +} + +func (f *fileConfigurator) restore() error { + err := copyFile(fileDefaultResolvConfBackupLocation, defaultResolvConfPath) + if err != nil { + return fmt.Errorf("got error while restoring the %s file from %s. Error: %s", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) + } + + return os.RemoveAll(fileDefaultResolvConfBackupLocation) +} + +func writeDNSConfig(content, fileName string, permissions os.FileMode) error { + log.Debugf("creating managed file %s", fileName) + var buf bytes.Buffer + buf.WriteString(content) + err := os.WriteFile(fileName, buf.Bytes(), permissions) + if err != nil { + return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err) + } + return nil +} + +func copyFile(src, dest string) error { + stats, err := os.Stat(src) + if err != nil { + return fmt.Errorf("got an error while checking stats for %s file when copying it. Error: %s", src, err) + } + + bytesRead, err := os.ReadFile(src) + if err != nil { + return fmt.Errorf("got an error while reading the file %s file for copy. Error: %s", src, err) + } + + err = os.WriteFile(dest, bytesRead, stats.Mode()) + if err != nil { + return fmt.Errorf("got an writing the destination file %s for copy. Error: %s", dest, err) + } + return nil +} diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go new file mode 100644 index 000000000..c077e2032 --- /dev/null +++ b/client/internal/dns/host.go @@ -0,0 +1,79 @@ +package dns + +import ( + "fmt" + nbdns "github.com/netbirdio/netbird/dns" + "strings" +) + +type hostManager interface { + applyDNSConfig(config hostDNSConfig) error + restoreHostDNS() error +} + +type hostDNSConfig struct { + domains []domainConfig + routeAll bool + serverIP string + serverPort int +} + +type domainConfig struct { + domain string + matchOnly bool +} + +type mockHostConfigurator struct { + applyDNSConfigFunc func(config hostDNSConfig) error + restoreHostDNSFunc func() error +} + +func (m *mockHostConfigurator) applyDNSConfig(config hostDNSConfig) error { + if m.applyDNSConfigFunc != nil { + return m.applyDNSConfigFunc(config) + } + return fmt.Errorf("method applyDNSSettings is not implemented") +} + +func (m *mockHostConfigurator) restoreHostDNS() error { + if m.restoreHostDNSFunc != nil { + return m.restoreHostDNSFunc() + } + return fmt.Errorf("method restoreHostDNS is not implemented") +} + +func newNoopHostMocker() hostManager { + return &mockHostConfigurator{ + applyDNSConfigFunc: func(config hostDNSConfig) error { return nil }, + restoreHostDNSFunc: func() error { return nil }, + } +} + +func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) hostDNSConfig { + config := hostDNSConfig{ + routeAll: false, + serverIP: ip, + serverPort: port, + } + for _, nsConfig := range dnsConfig.NameServerGroups { + if nsConfig.Primary { + config.routeAll = true + } + + for _, domain := range nsConfig.Domains { + config.domains = append(config.domains, domainConfig{ + domain: strings.TrimSuffix(domain, "."), + matchOnly: true, + }) + } + } + + for _, customZone := range dnsConfig.CustomZones { + config.domains = append(config.domains, domainConfig{ + domain: strings.TrimSuffix(customZone.Domain, "."), + matchOnly: false, + }) + } + + return config +} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go new file mode 100644 index 000000000..546561d88 --- /dev/null +++ b/client/internal/dns/host_darwin.go @@ -0,0 +1,259 @@ +package dns + +import ( + "bufio" + "bytes" + "fmt" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "os/exec" + "strconv" + "strings" +) + +const ( + netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" + globalIPv4State = "State:/Network/Global/IPv4" + primaryServiceSetupKeyFormat = "Setup:/Network/Service/%s/DNS" + keySupplementalMatchDomains = "SupplementalMatchDomains" + keySupplementalMatchDomainsNoSearch = "SupplementalMatchDomainsNoSearch" + keyServerAddresses = "ServerAddresses" + keyServerPort = "ServerPort" + arraySymbol = "* " + digitSymbol = "# " + scutilPath = "/usr/sbin/scutil" + searchSuffix = "Search" + matchSuffix = "Match" +) + +type systemConfigurator struct { + // primaryServiceID primary interface in the system. AKA the interface with the default route + primaryServiceID string + createdKeys map[string]struct{} +} + +func newHostManager(_ *iface.WGIface) (hostManager, error) { + return &systemConfigurator{ + createdKeys: make(map[string]struct{}), + }, nil +} + +func (s *systemConfigurator) applyDNSConfig(config hostDNSConfig) error { + var err error + + if config.routeAll { + err = s.addDNSSetupForAll(config.serverIP, config.serverPort) + if err != nil { + return err + } + } else if s.primaryServiceID != "" { + err = s.removeKeyFromSystemConfig(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) + if err != nil { + return err + } + s.primaryServiceID = "" + log.Infof("removed %s:%d as main DNS resolver for this peer", config.serverIP, config.serverPort) + } + + var ( + searchDomains []string + matchDomains []string + ) + + for _, dConf := range config.domains { + if dConf.matchOnly { + matchDomains = append(matchDomains, dConf.domain) + continue + } + searchDomains = append(searchDomains, dConf.domain) + } + + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + if len(matchDomains) != 0 { + err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.serverIP, config.serverPort) + } else { + log.Infof("removing match domains from the system") + err = s.removeKeyFromSystemConfig(matchKey) + } + if err != nil { + return err + } + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + if len(searchDomains) != 0 { + err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.serverIP, config.serverPort) + } else { + log.Infof("removing search domains from the system") + err = s.removeKeyFromSystemConfig(searchKey) + } + if err != nil { + return err + } + + return nil +} + +func (s *systemConfigurator) restoreHostDNS() error { + lines := "" + for key := range s.createdKeys { + lines += buildRemoveKeyOperation(key) + keyType := "search" + if strings.Contains(key, matchSuffix) { + keyType = "match" + } + log.Infof("removing %s domains from system", keyType) + } + if s.primaryServiceID != "" { + lines += buildRemoveKeyOperation(getKeyWithInput(primaryServiceSetupKeyFormat, s.primaryServiceID)) + log.Infof("restoring DNS resolver configuration for system") + } + _, err := runSystemConfigCommand(wrapCommand(lines)) + if err != nil { + log.Errorf("got an error while cleaning the system configuration: %s", err) + return err + } + + return nil +} + +func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { + line := buildRemoveKeyOperation(key) + _, err := runSystemConfigCommand(wrapCommand(line)) + if err != nil { + return err + } + + delete(s.createdKeys, key) + + return nil +} + +func (s *systemConfigurator) addSearchDomains(key, domains string, ip string, port int) error { + err := s.addDNSState(key, domains, ip, port, true) + if err != nil { + return err + } + + log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) + + s.createdKeys[key] = struct{}{} + + return nil +} + +func (s *systemConfigurator) addMatchDomains(key, domains, dnsServer string, port int) error { + err := s.addDNSState(key, domains, dnsServer, port, false) + if err != nil { + return err + } + + log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) + + s.createdKeys[key] = struct{}{} + + return nil +} + +func (s *systemConfigurator) addDNSState(state, domains, dnsServer string, port int, enableSearch bool) error { + noSearch := "1" + if enableSearch { + noSearch = "0" + } + lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains) + lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch) + lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) + lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) + + addDomainCommand := buildCreateStateWithOperation(state, lines) + stdinCommands := wrapCommand(addDomainCommand) + + _, err := runSystemConfigCommand(stdinCommands) + if err != nil { + return fmt.Errorf("got error while applying state for domains %s, error: %s", domains, err) + } + return nil +} + +func (s *systemConfigurator) addDNSSetupForAll(dnsServer string, port int) error { + primaryServiceKey := s.getPrimaryService() + if primaryServiceKey == "" { + return fmt.Errorf("couldn't find the primary service key") + } + + err := s.addDNSSetup(getKeyWithInput(primaryServiceSetupKeyFormat, primaryServiceKey), dnsServer, port) + if err != nil { + return err + } + log.Infof("configured %s:%d as main DNS resolver for this peer", dnsServer, port) + s.primaryServiceID = primaryServiceKey + return nil +} + +func (s *systemConfigurator) getPrimaryService() string { + line := buildCommandLine("show", globalIPv4State, "") + stdinCommands := wrapCommand(line) + b, err := runSystemConfigCommand(stdinCommands) + if err != nil { + log.Error("got error while sending the command: ", err) + return "" + } + scanner := bufio.NewScanner(bytes.NewReader(b)) + for scanner.Scan() { + text := scanner.Text() + if strings.Contains(text, "PrimaryService") { + return strings.TrimSpace(strings.Split(text, ":")[1]) + } + } + return "" +} + +func (s *systemConfigurator) addDNSSetup(setupKey, dnsServer string, port int) error { + lines := buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+strconv.Itoa(0)) + lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer) + lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) + addDomainCommand := buildCreateStateWithOperation(setupKey, lines) + stdinCommands := wrapCommand(addDomainCommand) + _, err := runSystemConfigCommand(stdinCommands) + if err != nil { + return fmt.Errorf("got error while applying dns setup, error: %s", err) + } + return nil +} + +func getKeyWithInput(format, key string) string { + return fmt.Sprintf(format, key) +} + +func buildAddCommandLine(key, value string) string { + return buildCommandLine("d.add", key, value) +} + +func buildCommandLine(action, key, value string) string { + return fmt.Sprintf("%s %s %s\n", action, key, value) +} + +func wrapCommand(commands string) string { + return fmt.Sprintf("open\n%s\nquit\n", commands) +} + +func buildRemoveKeyOperation(key string) string { + return fmt.Sprintf("remove %s\n", key) +} + +func buildCreateStateWithOperation(state, commands string) string { + return buildWriteStateOperation("set", state, commands) +} + +func buildWriteStateOperation(operation, state, commands string) string { + return fmt.Sprintf("d.init\n%s %s\n%s\nset %s\n", operation, state, commands, state) +} + +func runSystemConfigCommand(command string) ([]byte, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader(command) + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("got error while running system configuration command: \"%s\", error: %s", command, err) + } + return out, nil +} diff --git a/client/internal/dns/host_linux.go b/client/internal/dns/host_linux.go new file mode 100644 index 000000000..ffb5098c7 --- /dev/null +++ b/client/internal/dns/host_linux.go @@ -0,0 +1,75 @@ +package dns + +import ( + "bufio" + "fmt" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "os" + "strings" +) + +const ( + defaultResolvConfPath = "/etc/resolv.conf" +) + +const ( + netbirdManager osManagerType = iota + fileManager + networkManager + systemdManager + resolvConfManager +) + +type osManagerType int + +func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { + osManager, err := getOSDNSManagerType() + if err != nil { + return nil, err + } + + log.Debugf("discovered mode is: %d", osManager) + switch osManager { + case networkManager: + return newNetworkManagerDbusConfigurator(wgInterface) + case systemdManager: + return newSystemdDbusConfigurator(wgInterface) + default: + return newFileConfigurator() + } +} + +func getOSDNSManagerType() (osManagerType, error) { + + file, err := os.Open(defaultResolvConfPath) + if err != nil { + return 0, fmt.Errorf("unable to open %s for checking owner, got error: %s", defaultResolvConfPath, err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + text := scanner.Text() + if len(text) == 0 { + continue + } + if text[0] != '#' { + return fileManager, nil + } + if strings.Contains(text, fileGeneratedResolvConfContentHeader) { + return netbirdManager, nil + } + if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { + log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion()) + return networkManager, nil + } + if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { + return systemdManager, nil + } + if strings.Contains(text, "resolvconf") { + return resolvConfManager, nil + } + } + return fileManager, nil +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go new file mode 100644 index 000000000..e3f6cf34c --- /dev/null +++ b/client/internal/dns/host_windows.go @@ -0,0 +1,260 @@ +package dns + +import ( + "fmt" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows/registry" + "strings" +) + +const ( + dnsPolicyConfigMatchPath = "SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters\\DnsPolicyConfig\\NetBird-Match" + dnsPolicyConfigVersionKey = "Version" + dnsPolicyConfigVersionValue = 2 + dnsPolicyConfigNameKey = "Name" + dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers" + dnsPolicyConfigConfigOptionsKey = "ConfigOptions" + dnsPolicyConfigConfigOptionsValue = 0x8 +) + +const ( + interfaceConfigPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Interfaces" + interfaceConfigNameServerKey = "NameServer" + interfaceConfigSearchListKey = "SearchList" + tcpipParametersPath = "SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters" +) + +type registryConfigurator struct { + guid string + routingAll bool + existingSearchDomains []string +} + +func newHostManager(wgInterface *iface.WGIface) (hostManager, error) { + guid, err := wgInterface.GetInterfaceGUIDString() + if err != nil { + return nil, err + } + return ®istryConfigurator{ + guid: guid, + }, nil +} + +func (r *registryConfigurator) applyDNSConfig(config hostDNSConfig) error { + var err error + if config.routeAll { + err = r.addDNSSetupForAll(config.serverIP) + if err != nil { + return err + } + } else if r.routingAll { + err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey) + if err != nil { + return err + } + r.routingAll = false + log.Infof("removed %s as main DNS forwarder for this peer", config.serverIP) + } + + var ( + searchDomains []string + matchDomains []string + ) + + for _, dConf := range config.domains { + if !dConf.matchOnly { + searchDomains = append(searchDomains, dConf.domain) + } + matchDomains = append(matchDomains, "."+dConf.domain) + } + + if len(matchDomains) != 0 { + err = r.addDNSMatchPolicy(matchDomains, config.serverIP) + } else { + err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) + } + if err != nil { + return err + } + + err = r.updateSearchDomains(searchDomains) + if err != nil { + return err + } + + return nil +} + +func (r *registryConfigurator) addDNSSetupForAll(ip string) error { + err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip) + if err != nil { + return fmt.Errorf("adding dns setup for all failed with error: %s", err) + } + r.routingAll = true + log.Infof("configured %s:53 as main DNS forwarder for this peer", ip) + return nil +} + +func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error { + _, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE) + if err == nil { + err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath) + if err != nil { + return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err) + } + } + + regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", dnsPolicyConfigMatchPath, err) + } + + err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigVersionKey, err) + } + + err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigNameKey, err) + } + + err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigGenericDNSServersKey, err) + } + + err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue) + if err != nil { + return fmt.Errorf("unable to set registry value for %s, error: %s", dnsPolicyConfigConfigOptionsKey, err) + } + + log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains) + + return nil +} + +func (r *registryConfigurator) restoreHostDNS() error { + err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath) + if err != nil { + log.Error(err) + } + + return r.updateSearchDomains([]string{}) +} + +func (r *registryConfigurator) updateSearchDomains(domains []string) error { + value, err := getLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey) + if err != nil { + return fmt.Errorf("unable to get current search domains failed with error: %s", err) + } + + valueList := strings.Split(value, ",") + setExisting := false + if len(r.existingSearchDomains) == 0 { + r.existingSearchDomains = valueList + setExisting = true + } + + if len(domains) == 0 && setExisting { + log.Infof("added %d search domains to the registry. Domain list: %s", len(domains), domains) + return nil + } + + newList := append(r.existingSearchDomains, domains...) + + err = setLocalMachineRegistryKeyStringValue(tcpipParametersPath, interfaceConfigSearchListKey, strings.Join(newList, ",")) + if err != nil { + return fmt.Errorf("adding search domain failed with error: %s", err) + } + + log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains) + + return nil +} + +func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value string) error { + regKey, err := r.getInterfaceRegistryKey() + if err != nil { + return err + } + defer regKey.Close() + + err = regKey.SetStringValue(key, value) + if err != nil { + return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %s", key, value, err) + } + + return nil +} + +func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey string) error { + regKey, err := r.getInterfaceRegistryKey() + if err != nil { + return err + } + defer regKey.Close() + + err = regKey.DeleteValue(propertyKey) + if err != nil { + return fmt.Errorf("deleting registry key %s for interface failed with error: %s", propertyKey, err) + } + + return nil +} + +func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) { + var regKey registry.Key + + regKeyPath := interfaceConfigPath + "\\" + r.guid + + regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE) + if err != nil { + return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err) + } + + return regKey, nil +} + +func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE) + if err == nil { + k.Close() + err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath) + if err != nil { + return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %s", regKeyPath, err) + } + } + return nil +} + +func getLocalMachineRegistryKeyStringValue(keyPath, key string) (string, error) { + regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE) + if err != nil { + return "", fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err) + } + defer regKey.Close() + + val, _, err := regKey.GetStringValue(key) + if err != nil { + return "", fmt.Errorf("getting %s value for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, keyPath, err) + } + + return val, nil +} + +func setLocalMachineRegistryKeyStringValue(keyPath, key, value string) error { + regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("unable to open existing key from registry, key path: HKEY_LOCAL_MACHINE\\%s, error: %s", keyPath, err) + } + defer regKey.Close() + + err = regKey.SetStringValue(key, value) + if err != nil { + return fmt.Errorf("setting %s value %s for key path HKEY_LOCAL_MACHINE\\%s failed with error: %s", key, value, keyPath, err) + } + + return nil +} diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 741ab97b4..680fcc31a 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -1,6 +1,7 @@ package dns import ( + "fmt" "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" @@ -14,16 +15,16 @@ type localResolver struct { // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received question: %#v\n", r.Question[0]) - response := d.lookupRecord(r) - if response == nil { - log.Debugf("got empty response for question: %#v\n", r.Question[0]) - return - } - + log.Debugf("received question: %#v\n", r.Question[0]) replyMessage := &dns.Msg{} replyMessage.SetReply(r) - replyMessage.Answer = append(replyMessage.Answer, response) + replyMessage.RecursionAvailable = true + replyMessage.Rcode = dns.RcodeSuccess + + response := d.lookupRecord(r) + if response != nil { + replyMessage.Answer = append(replyMessage.Answer, response) + } err := w.WriteMsg(replyMessage) if err != nil { @@ -32,7 +33,8 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR { - record, found := d.records.Load(r.Question[0].Name) + question := r.Question[0] + record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype)) if !found { return nil } @@ -46,7 +48,10 @@ func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error { return err } - d.records.Store(fullRecord.Header().Name, fullRecord) + fullRecord.Header().Rdlength = record.Len() + + header := fullRecord.Header() + d.records.Store(buildRecordKey(header.Name, header.Class, header.Rrtype), fullRecord) return nil } @@ -54,3 +59,8 @@ func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error { func (d *localResolver) deleteRecord(recordKey string) { d.records.Delete(dns.Fqdn(recordKey)) } + +func buildRecordKey(name string, class, qType uint16) string { + key := fmt.Sprintf("%s_%d_%d", name, class, qType) + return key +} diff --git a/client/internal/dns/local_test.go b/client/internal/dns/local_test.go index 79a57881b..db69d9ad8 100644 --- a/client/internal/dns/local_test.go +++ b/client/internal/dns/local_test.go @@ -64,7 +64,7 @@ func TestLocalResolver_ServeDNS(t *testing.T) { resolver.ServeDNS(responseWriter, testCase.inputMSG) - if responseMSG == nil { + if responseMSG == nil || len(responseMSG.Answer) == 0 { if testCase.responseShouldBeNil { return } diff --git a/client/internal/dns/network_manager_linux.go b/client/internal/dns/network_manager_linux.go new file mode 100644 index 000000000..955d54923 --- /dev/null +++ b/client/internal/dns/network_manager_linux.go @@ -0,0 +1,295 @@ +package dns + +import ( + "context" + "encoding/binary" + "fmt" + "github.com/godbus/dbus/v5" + "github.com/hashicorp/go-version" + "github.com/miekg/dns" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "net/netip" + "regexp" + "time" +) + +const ( + networkManagerDest = "org.freedesktop.NetworkManager" + networkManagerDbusObjectNode = "/org/freedesktop/NetworkManager" + networkManagerDbusDNSManagerInterface = "org.freedesktop.NetworkManager.DnsManager" + networkManagerDbusDNSManagerObjectNode = networkManagerDbusObjectNode + "/DnsManager" + networkManagerDbusDNSManagerModeProperty = networkManagerDbusDNSManagerInterface + ".Mode" + networkManagerDbusDNSManagerRcManagerProperty = networkManagerDbusDNSManagerInterface + ".RcManager" + networkManagerDbusVersionProperty = "org.freedesktop.NetworkManager.Version" + networkManagerDbusGetDeviceByIPIfaceMethod = networkManagerDest + ".GetDeviceByIpIface" + networkManagerDbusDeviceInterface = "org.freedesktop.NetworkManager.Device" + networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection" + networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply" + networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete" + networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0 + networkManagerDbusIPv4Key = "ipv4" + networkManagerDbusIPv6Key = "ipv6" + networkManagerDbusDNSKey = "dns" + networkManagerDbusDNSSearchKey = "dns-search" + networkManagerDbusDNSPriorityKey = "dns-priority" + + // dns priority doc https://wiki.gnome.org/Projects/NetworkManager/DNS + networkManagerDbusPrimaryDNSPriority int32 = -500 + networkManagerDbusWithMatchDomainPriority int32 = 0 + networkManagerDbusSearchDomainOnlyPriority int32 = 50 + supportedNetworkManagerVersionConstraint = ">= 1.16, < 1.28" +) + +type networkManagerDbusConfigurator struct { + dbusLinkObject dbus.ObjectPath + routingAll bool +} + +// the types below are based on dbus specification, each field is mapped to a dbus type +// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types +// see https://networkmanager.dev/docs/api/latest/gdbus-org.freedesktop.NetworkManager.Device.html on Network Manager input types + +// networkManagerConnSettings maps to a (a{sa{sv}}) dbus output from GetAppliedConnection and input for Reapply methods +type networkManagerConnSettings map[string]map[string]dbus.Variant + +// networkManagerConfigVersion maps to a (t) dbus output from GetAppliedConnection and input for Reapply methods +type networkManagerConfigVersion uint64 + +// networkManagerConfigBehavior maps to a (u) dbus input for GetAppliedConnection and Reapply methods +type networkManagerConfigBehavior uint32 + +// cleanDeprecatedSettings cleans deprecated settings that still returned by +// the GetAppliedConnection methods but can't be reApplied +func (s networkManagerConnSettings) cleanDeprecatedSettings() { + for _, key := range []string{"addresses", "routes"} { + delete(s[networkManagerDbusIPv4Key], key) + delete(s[networkManagerDbusIPv6Key], key) + } +} + +func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) + if err != nil { + return nil, err + } + defer closeConn() + var s string + err = obj.Call(networkManagerDbusGetDeviceByIPIfaceMethod, dbusDefaultFlag, wgInterface.GetName()).Store(&s) + if err != nil { + return nil, err + } + + log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface.GetName()) + + return &networkManagerDbusConfigurator{ + dbusLinkObject: dbus.ObjectPath(s), + }, nil +} + +func (n *networkManagerDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { + connSettings, configVersion, err := n.getAppliedConnectionSettings() + if err != nil { + return fmt.Errorf("got an error while retrieving the applied connection settings, error: %s", err) + } + + connSettings.cleanDeprecatedSettings() + + dnsIP := netip.MustParseAddr(config.serverIP) + convDNSIP := binary.LittleEndian.Uint32(dnsIP.AsSlice()) + connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSKey] = dbus.MakeVariant([]uint32{convDNSIP}) + var ( + searchDomains []string + matchDomains []string + ) + for _, dConf := range config.domains { + if dConf.matchOnly { + matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.domain)) + continue + } + searchDomains = append(searchDomains, dns.Fqdn(dConf.domain)) + } + + newDomainList := append(searchDomains, matchDomains...) + + priority := networkManagerDbusSearchDomainOnlyPriority + switch { + case config.routeAll: + priority = networkManagerDbusPrimaryDNSPriority + newDomainList = append(newDomainList, "~.") + if !n.routingAll { + log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + } + case len(matchDomains) > 0: + priority = networkManagerDbusWithMatchDomainPriority + } + + if priority != networkManagerDbusPrimaryDNSPriority && n.routingAll { + log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + n.routingAll = false + } + + connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) + connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) + + log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) + err = n.reApplyConnectionSettings(connSettings, configVersion) + if err != nil { + return fmt.Errorf("got an error while reapplying the connection with new settings, error: %s", err) + } + return nil +} + +func (n *networkManagerDbusConfigurator) restoreHostDNS() error { + // once the interface is gone network manager cleans all config associated with it + return n.deleteConnectionSettings() +} + +func (n *networkManagerDbusConfigurator) getAppliedConnectionSettings() (networkManagerConnSettings, networkManagerConfigVersion, error) { + obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) + if err != nil { + return nil, 0, fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + var ( + connSettings networkManagerConnSettings + configVersion networkManagerConfigVersion + ) + + err = obj.CallWithContext(ctx, networkManagerDbusDeviceGetAppliedConnectionMethod, dbusDefaultFlag, + networkManagerDbusDefaultBehaviorFlag).Store(&connSettings, &configVersion) + if err != nil { + return nil, 0, fmt.Errorf("got error while calling GetAppliedConnection method with context, err: %s", err) + } + + return connSettings, configVersion, nil +} + +func (n *networkManagerDbusConfigurator) reApplyConnectionSettings(connSettings networkManagerConnSettings, configVersion networkManagerConfigVersion) error { + obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, networkManagerDbusDeviceReapplyMethod, dbusDefaultFlag, + connSettings, configVersion, networkManagerDbusDefaultBehaviorFlag).Store() + if err != nil { + return fmt.Errorf("got error while calling ReApply method with context, err: %s", err) + } + + return nil +} + +func (n *networkManagerDbusConfigurator) deleteConnectionSettings() error { + obj, closeConn, err := getDbusObject(networkManagerDest, n.dbusLinkObject) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the applied connection settings, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, networkManagerDbusDeviceDeleteMethod, dbusDefaultFlag).Store() + if err != nil { + return fmt.Errorf("got error while calling delete method with context, err: %s", err) + } + + return nil +} + +func isNetworkManagerSupported() bool { + return isNetworkManagerSupportedVersion() && isNetworkManagerSupportedMode() +} + +func isNetworkManagerSupportedMode() bool { + var mode string + err := getNetworkManagerDNSProperty(networkManagerDbusDNSManagerModeProperty, &mode) + if err != nil { + log.Error(err) + return false + } + switch mode { + case "dnsmasq", "unbound", "systemd-resolved": + return true + default: + var rcManager string + err = getNetworkManagerDNSProperty(networkManagerDbusDNSManagerRcManagerProperty, &rcManager) + if err != nil { + log.Error(err) + return false + } + if rcManager == "unmanaged" { + return false + } + } + return true +} + +func getNetworkManagerDNSProperty(property string, store any) error { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusDNSManagerObjectNode) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the network manager dns manager object, error: %s", err) + } + defer closeConn() + + v, e := obj.GetProperty(property) + if e != nil { + return fmt.Errorf("got an error getting property %s: %v", property, e) + } + + return v.Store(store) +} + +func isNetworkManagerSupportedVersion() bool { + obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) + if err != nil { + log.Errorf("got error while attempting to get the network manager object, err: %s", err) + return false + } + + defer closeConn() + + value, err := obj.GetProperty(networkManagerDbusVersionProperty) + if err != nil { + log.Errorf("unable to retrieve network manager mode, got error: %s", err) + return false + } + versionValue, err := parseVersion(value.Value().(string)) + if err != nil { + return false + } + + constraints, err := version.NewConstraint(supportedNetworkManagerVersionConstraint) + if err != nil { + return false + } + + return constraints.Check(versionValue) +} + +func parseVersion(inputVersion string) (*version.Version, error) { + reg, err := regexp.Compile(version.SemverRegexpRaw) + if err != nil { + return nil, err + } + + if inputVersion == "" || !reg.MatchString(inputVersion) { + return nil, fmt.Errorf("couldn't parse the provided version: Not SemVer") + } + + verObj, err := version.NewVersion(inputVersion) + if err != nil { + return nil, err + } + + return verObj, nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 67f3788ea..91a38cd4a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,14 +5,19 @@ import ( "fmt" "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" log "github.com/sirupsen/logrus" + "net" + "net/netip" + "runtime" "sync" "time" ) const ( - port = 5053 - defaultIP = "0.0.0.0" + port = 53 + customPort = 5053 + defaultIP = "127.0.0.1" ) // Server is a dns server interface @@ -31,8 +36,12 @@ type DefaultServer struct { dnsMux *dns.ServeMux dnsMuxMap registrationMap localResolver *localResolver + wgInterface *iface.WGIface + hostManager hostManager updateSerial uint64 listenerIsRunning bool + runtimePort int + runtimeIP string } type registrationMap map[string]struct{} @@ -43,11 +52,15 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context) *DefaultServer { +func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface) (*DefaultServer, error) { mux := dns.NewServeMux() + listenIP := defaultIP + if runtime.GOOS != "darwin" && wgInterface != nil { + listenIP = wgInterface.GetAddress().IP.String() + } dnsServer := &dns.Server{ - Addr: fmt.Sprintf("%s:%d", defaultIP, port), + Addr: fmt.Sprintf("%s:%d", listenIP, port), Net: "udp", Handler: mux, UDPSize: 65535, @@ -55,7 +68,7 @@ func NewDefaultServer(ctx context.Context) *DefaultServer { ctx, stop := context.WithCancel(ctx) - return &DefaultServer{ + defaultServer := &DefaultServer{ ctx: ctx, stop: stop, server: dnsServer, @@ -64,18 +77,44 @@ func NewDefaultServer(ctx context.Context) *DefaultServer { localResolver: &localResolver{ registeredMap: make(registrationMap), }, + wgInterface: wgInterface, + runtimePort: port, + runtimeIP: listenIP, } + + hostmanager, err := newHostManager(wgInterface) + if err != nil { + return nil, err + } + defaultServer.hostManager = hostmanager + return defaultServer, err } // Start runs the listener in a go routine func (s *DefaultServer) Start() { - log.Debugf("starting dns on %s:%d", defaultIP, port) + s.runtimePort = port + udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(s.server.Addr)) + probeListener, err := net.ListenUDP("udp", udpAddr) + if err != nil { + log.Warnf("using a custom port for dns server") + s.runtimePort = customPort + s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, customPort) + } else { + err = probeListener.Close() + if err != nil { + log.Errorf("got an error closing the probe listener, error: %s", err) + } + } + + log.Debugf("starting dns on %s", s.server.Addr) + go func() { s.setListenerStatus(true) defer s.setListenerStatus(false) - err := s.server.ListenAndServe() + + err = s.server.ListenAndServe() if err != nil { - log.Errorf("dns server returned an error: %v", err) + log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err) } }() } @@ -86,9 +125,16 @@ func (s *DefaultServer) setListenerStatus(running bool) { // Stop stops the server func (s *DefaultServer) Stop() { + s.mux.Lock() + defer s.mux.Unlock() s.stop() - err := s.stopListener() + err := s.hostManager.restoreHostDNS() + if err != nil { + log.Error(err) + } + + err = s.stopListener() if err != nil { log.Error(err) } @@ -148,6 +194,11 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro s.updateMux(muxUpdates) s.updateLocalResolver(localRecords) + err = s.hostManager.applyDNSConfig(dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)) + if err != nil { + log.Error(err) + } + s.updateSerial = serial return nil @@ -170,7 +221,12 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) }) for _, record := range customZone.Records { - localRecords[record.Name] = record + var class uint16 = dns.ClassINET + if record.Class != nbdns.DefaultClass { + return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class) + } + key := buildRecordKey(record.Name, class, uint16(record.Type)) + localRecords[key] = record } } return muxUpdates, localRecords, nil diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 6bbfef507..b0b8cd1ec 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -3,6 +3,7 @@ package dns import ( "context" "fmt" + "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" "net" "net/netip" @@ -74,12 +75,12 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}, nbdns.RootZone: struct{}{}}, - expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "New Config Should Succeed", initLocalMap: registrationMap{"netbird.cloud": struct{}{}}, - initUpstreamMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + initUpstreamMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, initSerial: 0, inputSerial: 1, inputUpdate: nbdns.Config{ @@ -98,7 +99,7 @@ func TestUpdateDNSServer(t *testing.T) { }, }, expectedUpstreamMap: registrationMap{"netbird.io": struct{}{}, "netbird.cloud": struct{}{}}, - expectedLocalMap: registrationMap{zoneRecords[0].Name: struct{}{}}, + expectedLocalMap: registrationMap{buildRecordKey(zoneRecords[0].Name, 1, 1): struct{}{}}, }, { name: "Smaller Config Serial Should Be Skipped", @@ -188,12 +189,14 @@ func TestUpdateDNSServer(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - ctx := context.Background() - dnsServer := NewDefaultServer(ctx) + dnsServer := getDefaultServerWithNoHostManager("127.0.0.1") + + dnsServer.hostManager = newNoopHostMocker() dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.localResolver.registeredMap = testCase.initLocalMap dnsServer.updateSerial = testCase.initSerial + // pretend we are running dnsServer.listenerIsRunning = true err := dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) @@ -230,13 +233,15 @@ func TestUpdateDNSServer(t *testing.T) { } func TestDNSServerStartStop(t *testing.T) { - ctx := context.Background() - dnsServer := NewDefaultServer(ctx) + dnsServer := getDefaultServerWithNoHostManager("127.0.0.1") + if runtime.GOOS == "windows" && os.Getenv("CI") == "true" { // todo review why this test is not working only on github actions workflows t.Skip("skipping test in Windows CI workflows.") } + dnsServer.hostManager = newNoopHostMocker() + dnsServer.Start() err := dnsServer.localResolver.registerRecord(zoneRecords[0]) @@ -276,10 +281,40 @@ func TestDNSServerStartStop(t *testing.T) { } dnsServer.Stop() - ctx, cancel := context.WithTimeout(ctx, time.Second*1) + ctx, cancel := context.WithTimeout(context.TODO(), time.Second*1) defer cancel() _, err = resolver.LookupHost(ctx, zoneRecords[0].Name) if err == nil { t.Fatalf("we should encounter an error when querying a stopped server") } } + +func getDefaultServerWithNoHostManager(ip string) *DefaultServer { + mux := dns.NewServeMux() + listenIP := defaultIP + if ip != "" { + listenIP = ip + } + + dnsServer := &dns.Server{ + Addr: fmt.Sprintf("%s:%d", ip, port), + Net: "udp", + Handler: mux, + UDPSize: 65535, + } + + ctx, stop := context.WithCancel(context.TODO()) + + return &DefaultServer{ + ctx: ctx, + stop: stop, + server: dnsServer, + dnsMux: mux, + dnsMuxMap: make(registrationMap), + localResolver: &localResolver{ + registeredMap: make(registrationMap), + }, + runtimePort: port, + runtimeIP: listenIP, + } +} diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go new file mode 100644 index 000000000..54a73968a --- /dev/null +++ b/client/internal/dns/systemd_linux.go @@ -0,0 +1,185 @@ +package dns + +import ( + "context" + "fmt" + "github.com/godbus/dbus/v5" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/iface" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + "net" + "net/netip" + "time" +) + +const ( + systemdDbusManagerInterface = "org.freedesktop.resolve1.Manager" + systemdResolvedDest = "org.freedesktop.resolve1" + systemdDbusObjectNode = "/org/freedesktop/resolve1" + systemdDbusGetLinkMethod = systemdDbusManagerInterface + ".GetLink" + systemdDbusFlushCachesMethod = systemdDbusManagerInterface + ".FlushCaches" + systemdDbusLinkInterface = "org.freedesktop.resolve1.Link" + systemdDbusRevertMethodSuffix = systemdDbusLinkInterface + ".Revert" + systemdDbusSetDNSMethodSuffix = systemdDbusLinkInterface + ".SetDNS" + systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" + systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" +) + +type systemdDbusConfigurator struct { + dbusLinkObject dbus.ObjectPath + routingAll bool +} + +// the types below are based on dbus specification, each field is mapped to a dbus type +// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types +// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types +// systemdDbusDNSInput maps to a (iay) dbus input for SetDNS method +type systemdDbusDNSInput struct { + Family int32 + Address []byte +} + +// systemdDbusLinkDomainsInput maps to a (sb) dbus input for SetDomains method +type systemdDbusLinkDomainsInput struct { + Domain string + MatchOnly bool +} + +func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) { + iface, err := net.InterfaceByName(wgInterface.GetName()) + if err != nil { + return nil, err + } + + obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) + if err != nil { + return nil, err + } + defer closeConn() + + var s string + err = obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, iface.Index).Store(&s) + if err != nil { + return nil, err + } + + log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index) + + return &systemdDbusConfigurator{ + dbusLinkObject: dbus.ObjectPath(s), + }, nil +} + +func (s *systemdDbusConfigurator) applyDNSConfig(config hostDNSConfig) error { + parsedIP := netip.MustParseAddr(config.serverIP).As4() + defaultLinkInput := systemdDbusDNSInput{ + Family: unix.AF_INET, + Address: parsedIP[:], + } + err := s.callLinkMethod(systemdDbusSetDNSMethodSuffix, []systemdDbusDNSInput{defaultLinkInput}) + if err != nil { + return fmt.Errorf("setting the interface DNS server %s:%d failed with error: %s", config.serverIP, config.serverPort, err) + } + + var ( + searchDomains []string + matchDomains []string + domainsInput []systemdDbusLinkDomainsInput + ) + for _, dConf := range config.domains { + domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ + Domain: dns.Fqdn(dConf.domain), + MatchOnly: dConf.matchOnly, + }) + + if dConf.matchOnly { + matchDomains = append(matchDomains, dConf.domain) + continue + } + searchDomains = append(searchDomains, dConf.domain) + } + + if config.routeAll { + log.Infof("configured %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true) + if err != nil { + return fmt.Errorf("setting link as default dns router, failed with error: %s", err) + } + domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{ + Domain: nbdns.RootZone, + MatchOnly: true, + }) + s.routingAll = true + } else if s.routingAll { + log.Infof("removing %s:%d as main DNS forwarder for this peer", config.serverIP, config.serverPort) + } + + log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) + err = s.setDomainsForInterface(domainsInput) + if err != nil { + log.Error(err) + } + return nil +} + +func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdDbusLinkDomainsInput) error { + err := s.callLinkMethod(systemdDbusSetDomainsMethodSuffix, domainsInput) + if err != nil { + return fmt.Errorf("setting domains configuration failed with error: %s", err) + } + return s.flushCaches() +} + +func (s *systemdDbusConfigurator) restoreHostDNS() error { + log.Infof("reverting link settings and flushing cache") + if !isDbusListenerRunning(systemdResolvedDest, s.dbusLinkObject) { + return nil + } + err := s.callLinkMethod(systemdDbusRevertMethodSuffix, nil) + if err != nil { + return fmt.Errorf("unable to revert link configuration, got error: %s", err) + } + return s.flushCaches() +} + +func (s *systemdDbusConfigurator) flushCaches() error { + obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the object %s, err: %s", systemdDbusObjectNode, err) + } + defer closeConn() + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + err = obj.CallWithContext(ctx, systemdDbusFlushCachesMethod, dbusDefaultFlag).Store() + if err != nil { + return fmt.Errorf("got error while calling the FlushCaches method with context, err: %s", err) + } + + return nil +} + +func (s *systemdDbusConfigurator) callLinkMethod(method string, value any) error { + obj, closeConn, err := getDbusObject(systemdResolvedDest, s.dbusLinkObject) + if err != nil { + return fmt.Errorf("got error while attempting to retrieve the object, err: %s", err) + } + defer closeConn() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + if value != nil { + err = obj.CallWithContext(ctx, method, dbusDefaultFlag, value).Store() + } else { + err = obj.CallWithContext(ctx, method, dbusDefaultFlag).Store() + } + + if err != nil { + return fmt.Errorf("got error while calling command with context, err: %s", err) + } + + return nil +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index fcc8bc685..e2e61203c 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -21,7 +21,7 @@ type upstreamResolver struct { // ServeDNS handles a DNS request func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received an upstream question: %#v", r.Question[0]) + log.Debugf("received an upstream question: %#v", r.Question[0]) select { case <-u.parentCTX.Done(): diff --git a/client/internal/engine.go b/client/internal/engine.go index 82425e62a..7e64c003d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -139,7 +139,6 @@ func NewEngine( networkSerial: 0, sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, - dnsServer: dns.NewDefaultServer(ctx), } } @@ -261,6 +260,14 @@ func (e *Engine) Start() error { e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder) + if e.dnsServer == nil { + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface) + if err != nil { + return err + } + e.dnsServer = dnsServer + } + e.receiveSignalEvents() e.receiveManagementEvents() diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 56d9eb66f..9e80f144d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -202,6 +202,9 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, nbstatus.NewRecorder()) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) + engine.dnsServer = &dns.MockServer{ + UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, + } type testCase struct { name string @@ -551,6 +554,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { } engine.routeManager = mockRouteManager + engine.dnsServer = &dns.MockServer{} defer func() { exitErr := engine.Stop() @@ -797,6 +801,7 @@ func TestEngine_MultiplePeers(t *testing.T) { t.Errorf("unable to create the engine for peer %d with error %v", j, err) return } + engine.dnsServer = &dns.MockServer{} mu.Lock() defer mu.Unlock() err = engine.Start() diff --git a/dns/dns.go b/dns/dns.go index a09e4b5df..16ebd1d96 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/miekg/dns" "golang.org/x/net/idna" + "net" "regexp" "strings" ) @@ -60,6 +61,30 @@ func (s SimpleRecord) String() string { return fmt.Sprintf("%s %d %s %s %s", fqdn, s.TTL, s.Class, dns.Type(s.Type).String(), s.RData) } +// Len returns the length of the RData field, based on its type +func (s SimpleRecord) Len() uint16 { + emptyString := s.RData == "" + switch s.Type { + case 1: + if emptyString { + return 0 + } + return net.IPv4len + case 5: + if emptyString || s.RData == "." { + return 1 + } + return uint16(len(s.RData) + 1) + case 28: + if emptyString { + return 0 + } + return net.IPv6len + default: + return 0 + } +} + // GetParsedDomainLabel returns a domain label with max 59 characters, // parsed for old Hosts.txt requirements, and converted to ASCII and lowercase func GetParsedDomainLabel(name string) (string, error) { diff --git a/go.mod b/go.mod index e6c0528d2..b7f92bc73 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/eko/gocache/v3 v3.1.1 github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 + github.com/godbus/dbus/v5 v5.1.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/hashicorp/go-version v1.6.0 github.com/libp2p/go-netroute v0.2.0 @@ -75,7 +76,6 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-stack/stack v1.8.0 // indirect - github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/gopacket v1.1.19 // indirect diff --git a/go.sum b/go.sum index 3a4338ff7..707d7c808 100644 --- a/go.sum +++ b/go.sum @@ -223,8 +223,9 @@ github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/godbus/dbus/v5 v5.0.4 h1:9349emZab16e7zQvpmsbtjc18ykshndd8y2PG3sgJbA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= diff --git a/iface/iface.go b/iface/iface.go index bdfa78abb..d75c4db86 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -73,13 +73,15 @@ func parseAddress(address string) (WGAddress, error) { func (w *WGIface) Close() error { w.mu.Lock() defer w.mu.Unlock() - + if w.Interface == nil { + return nil + } err := w.Interface.Close() if err != nil { return err } - if runtime.GOOS == "darwin" { + if runtime.GOOS != "windows" { sockPath := "/var/run/wireguard/" + w.Name + ".sock" if _, statErr := os.Stat(sockPath); statErr == nil { statErr = os.Remove(sockPath) diff --git a/iface/iface_unix.go b/iface/iface_unix.go index 66d316997..ebac5d8a1 100644 --- a/iface/iface_unix.go +++ b/iface/iface_unix.go @@ -75,3 +75,8 @@ func (w *WGIface) UpdateAddr(newAddr string) error { w.Address = addr return w.assignAddr() } + +// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only +func (w *WGIface) GetInterfaceGUIDString() (string, error) { + return "", nil +} diff --git a/iface/iface_windows.go b/iface/iface_windows.go index d38cd3dc4..5c16916b9 100644 --- a/iface/iface_windows.go +++ b/iface/iface_windows.go @@ -58,6 +58,20 @@ func (w *WGIface) UpdateAddr(newAddr string) error { return w.assignAddr(luid) } +// GetInterfaceGUIDString returns an interface GUID string +func (w *WGIface) GetInterfaceGUIDString() (string, error) { + if w.Interface == nil { + return "", fmt.Errorf("interface has not been initialized yet") + } + windowsDevice := w.Interface.(*driver.Adapter) + luid := windowsDevice.LUID() + guid, err := luid.GUID() + if err != nil { + return "", err + } + return guid.String(), nil +} + // WireguardModuleIsLoaded check if we can load wireguard mod (linux only) func WireguardModuleIsLoaded() bool { return false