From 7e6beee7f6bb18865c544c959f8dd5defe4b3762 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 19 Feb 2025 19:13:45 +0100 Subject: [PATCH] [management] optimize test execution (#3204) --- .github/workflows/golang-test-linux.yml | 150 ++- .golangci.yaml | 2 +- management/client/client_test.go | 5 +- management/server/dns_test.go | 14 +- management/server/group_test.go | 6 +- management/server/management_suite_test.go | 13 - management/server/management_test.go | 1046 ++++++++++++-------- management/server/nameserver_test.go | 13 +- management/server/route_test.go | 5 +- management/server/store/sql_store_test.go | 244 +++-- management/server/store/store.go | 152 ++- management/server/testutil/store.go | 4 +- management/server/types/user.go | 2 +- relay/client/dialer/ws/ws.go | 2 +- relay/server/listener/ws/conn.go | 2 +- relay/server/listener/ws/listener.go | 2 +- 16 files changed, 1019 insertions(+), 643 deletions(-) delete mode 100644 management/server/management_suite_test.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index b1ec0a896..efe1a2654 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -1,4 +1,4 @@ -name: Test Code Linux +name: Linux on: push: @@ -12,6 +12,7 @@ concurrency: jobs: build-cache: + name: "Build Cache" runs-on: ubuntu-22.04 outputs: management: ${{ steps.filter.outputs.management }} @@ -47,7 +48,6 @@ jobs: key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} - - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' @@ -98,6 +98,7 @@ jobs: run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . test: + name: "Client / Unit" needs: [build-cache] strategy: fail-fast: false @@ -143,9 +144,116 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) + + test_relay: + name: "Relay / Unit" + needs: [build-cache] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + go test \ + -exec 'sudo' \ + -timeout 10m ./signal/... + + test_signal: + name: "Signal / Unit" + needs: [build-cache] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + go test \ + -exec 'sudo' \ + -timeout 10m ./signal/... test_management: + name: "Management / Unit" needs: [ build-cache ] strategy: fail-fast: false @@ -203,9 +311,15 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} \ + go test -tags=devcert \ + -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \ + -timeout 10m ./management/... benchmark: + name: "Management / Benchmark" needs: [ build-cache ] if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }} strategy: @@ -264,9 +378,15 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ + go test -tags devcert -run=^$ -bench=. \ + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ + -timeout 20m ./... api_benchmark: + name: "Management / Benchmark (API)" needs: [ build-cache ] if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }} strategy: @@ -323,11 +443,19 @@ jobs: - name: download mysql image if: matrix.store == 'mysql' run: docker pull mlsmaycon/warmed-mysql:8 - + - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management) + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ + go test -tags=benchmark \ + -run=^$ \ + -bench=. \ + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ + -timeout 20m ./management/... api_integration_test: + name: "Management / Integration" needs: [ build-cache ] if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }} strategy: @@ -375,9 +503,15 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management) + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ + go test -tags=integration \ + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ + -timeout 10m ./management/... test_client_on_docker: + name: "Client (Docker) / Unit" needs: [ build-cache ] runs-on: ubuntu-20.04 steps: diff --git a/.golangci.yaml b/.golangci.yaml index 44b03d0e1..461677c2e 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -103,7 +103,7 @@ linters: - predeclared # predeclared finds code that shadows one of Go's predeclared identifiers - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed - - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers. + # - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers. - wastedassign # wastedassign finds wasted assignment statements issues: # Maximum count of issues with the same text. diff --git a/management/client/client_test.go b/management/client/client_test.go index 3e498a5ea..b4ee58298 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -258,8 +258,11 @@ func TestClient_Sync(t *testing.T) { ch := make(chan *mgmtProto.SyncResponse, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { - err = client.Sync(context.Background(), info, func(msg *mgmtProto.SyncResponse) error { + err = client.Sync(ctx, info, func(msg *mgmtProto.SyncResponse) error { ch <- msg return nil }) diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 6fb9f6a29..c40f62324 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -42,7 +42,7 @@ func TestGetDNSSettings(t *testing.T) { account, err := initTestDNSAccount(t, am) if err != nil { - t.Fatal("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) @@ -124,12 +124,12 @@ func TestSaveDNSSettings(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { am, err := createDNSManager(t) if err != nil { - t.Error("failed to create account manager") + t.Fatalf("failed to create account manager") } account, err := initTestDNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %v", err) } err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) @@ -156,22 +156,22 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { am, err := createDNSManager(t) if err != nil { - t.Error("failed to create account manager") + t.Fatalf("failed to create account manager: %s", err) } account, err := initTestDNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } peer1, err := account.FindPeerByPubKey(dnsPeer1Key) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } peer2, err := account.FindPeerByPubKey(dnsPeer2Key) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) diff --git a/management/server/group_test.go b/management/server/group_test.go index cc90f187b..b21b5e834 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -29,7 +29,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { _, account, err := initTestGroupAccount(am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } for _, group := range account.Groups { group.Issued = types.GroupIssuedIntegration @@ -59,12 +59,12 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { func TestDefaultAccountManager_DeleteGroup(t *testing.T) { am, err := createManager(t) if err != nil { - t.Error("failed to create account manager") + t.Fatalf("failed to create account manager: %s", err) } _, account, err := initTestGroupAccount(am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } testCases := []struct { diff --git a/management/server/management_suite_test.go b/management/server/management_suite_test.go deleted file mode 100644 index cc99624a0..000000000 --- a/management/server/management_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package server_test - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "testing" -) - -func TestManagement(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Management Service Suite") -} diff --git a/management/server/management_test.go b/management/server/management_test.go index 43a6e40d5..1b91b3447 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -6,13 +6,13 @@ import ( "net" "os" "runtime" - sync2 "sync" + "sync" + "testing" "time" pb "github.com/golang/protobuf/proto" //nolint - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -30,424 +30,77 @@ import ( const ( ValidSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - AccountKey = "bf1c8084-ba50-4ce7-9439-34653001fc3b" ) -var _ = Describe("Management service", func() { - var ( - addr string - s *grpc.Server - dataDir string - client mgmtProto.ManagementServiceClient - serverPubKey wgtypes.Key - conn *grpc.ClientConn - ) - - BeforeEach(func() { - level, _ := log.ParseLevel("Debug") - log.SetLevel(level) - var err error - dataDir, err = os.MkdirTemp("", "netbird_mgmt_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - - var listener net.Listener - - config := &server.Config{} - _, err = util.ReadJson("testdata/management.json", config) - Expect(err).NotTo(HaveOccurred()) - config.Datadir = dataDir - - s, listener = startServer(config, dataDir, "testdata/store.sql") - addr = listener.Addr().String() - client, conn = createRawClient(addr) - - // s public key - resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) - Expect(err).NotTo(HaveOccurred()) - serverPubKey, err = wgtypes.ParseKey(resp.Key) - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - s.Stop() - err := conn.Close() - Expect(err).NotTo(HaveOccurred()) - os.RemoveAll(dataDir) - }) - - Context("when calling IsHealthy endpoint", func() { - Specify("a non-error result is returned", func() { - healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{}) - - Expect(err).NotTo(HaveOccurred()) - Expect(healthy).ToNot(BeNil()) - }) - }) - - Context("when calling Sync endpoint", func() { - Context("when there is a new peer registered", func() { - Specify("a proper configuration is returned", func() { - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - - syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} - encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq) - Expect(err).NotTo(HaveOccurred()) - - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = encryption.DecryptMessage(serverPubKey, key, encryptedResponse.Body, resp) - Expect(err).NotTo(HaveOccurred()) - - expectedSignalConfig := &mgmtProto.HostConfig{ - Uri: "signal.netbird.io:10000", - Protocol: mgmtProto.HostConfig_HTTP, - } - expectedStunsConfig := &mgmtProto.HostConfig{ - Uri: "stun:stun.netbird.io:3468", - Protocol: mgmtProto.HostConfig_UDP, - } - expectedTRUNHost := &mgmtProto.HostConfig{ - Uri: "turn:stun.netbird.io:3468", - Protocol: mgmtProto.HostConfig_UDP, - } - - Expect(resp.NetbirdConfig.Signal).To(BeEquivalentTo(expectedSignalConfig)) - Expect(resp.NetbirdConfig.Stuns).To(ConsistOf(expectedStunsConfig)) - // TURN validation is special because credentials are dynamically generated - Expect(resp.NetbirdConfig.Turns).To(HaveLen(1)) - actualTURN := resp.NetbirdConfig.Turns[0] - Expect(len(actualTURN.User) > 0).To(BeTrue()) - Expect(actualTURN.HostConfig).To(BeEquivalentTo(expectedTRUNHost)) - Expect(len(resp.NetworkMap.OfflinePeers) == 0).To(BeTrue()) - }) - }) - - Context("when there are 3 peers registered under one account", func() { - Specify("a list containing other 2 peers is returned", func() { - key, _ := wgtypes.GenerateKey() - key1, _ := wgtypes.GenerateKey() - key2, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - loginPeerWithValidSetupKey(serverPubKey, key1, client) - loginPeerWithValidSetupKey(serverPubKey, key2, client) - - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - Expect(err).NotTo(HaveOccurred()) - - Expect(resp.GetRemotePeers()).To(HaveLen(2)) - peers := []string{resp.GetRemotePeers()[0].WgPubKey, resp.GetRemotePeers()[1].WgPubKey} - Expect(peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String())) - }) - }) - - Context("when there is a new peer registered", func() { - Specify("an update is returned", func() { - // register only a single peer - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - - // after the initial sync call we have 0 peer updates - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - resp := &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - Expect(resp.GetRemotePeers()).To(HaveLen(0)) - - wg := sync2.WaitGroup{} - wg.Add(1) - - // continue listening on updates for a peer - go func() { - err = sync.RecvMsg(encryptedResponse) - - decryptedBytes, err = encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - resp = &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - wg.Done() - }() - - // register a new peer - key1, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key1, client) - - wg.Wait() - - Expect(err).NotTo(HaveOccurred()) - Expect(resp.GetRemotePeers()).To(HaveLen(1)) - Expect(resp.GetRemotePeers()[0].WgPubKey).To(BeEquivalentTo(key1.PublicKey().String())) - }) - }) - }) - - Context("when calling GetServerKey endpoint", func() { - Specify("a public Wireguard key of the service is returned", func() { - resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) - - Expect(err).NotTo(HaveOccurred()) - Expect(resp).ToNot(BeNil()) - Expect(resp.Key).ToNot(BeNil()) - Expect(resp.ExpiresAt).ToNot(BeNil()) - - // check if the key is a valid Wireguard key - key, err := wgtypes.ParseKey(resp.Key) - Expect(err).NotTo(HaveOccurred()) - Expect(key).ToNot(BeNil()) - }) - }) - - Context("when calling Login endpoint", func() { - Context("with an invalid setup key", func() { - Specify("an error is returned", func() { - key, _ := wgtypes.GenerateKey() - message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: "invalid setup key", - Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - - resp, err := client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: message, - }) - - Expect(err).To(HaveOccurred()) - Expect(resp).To(BeNil()) - }) - }) - - Context("with a valid setup key", func() { - It("a non error result is returned", func() { - key, _ := wgtypes.GenerateKey() - resp := loginPeerWithValidSetupKey(serverPubKey, key, client) - - Expect(resp).ToNot(BeNil()) - }) - }) - - Context("with a registered peer", func() { - It("a non error result is returned", func() { - key, _ := wgtypes.GenerateKey() - regResp := loginPeerWithValidSetupKey(serverPubKey, key, client) - Expect(regResp).NotTo(BeNil()) - - // just login without registration - message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - loginResp, err := client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: message, - }) - - Expect(err).NotTo(HaveOccurred()) - - decryptedResp := &mgmtProto.LoginResponse{} - err = encryption.DecryptMessage(serverPubKey, key, loginResp.Body, decryptedResp) - Expect(err).NotTo(HaveOccurred()) - - expectedSignalConfig := &mgmtProto.HostConfig{ - Uri: "signal.netbird.io:10000", - Protocol: mgmtProto.HostConfig_HTTP, - } - expectedStunsConfig := &mgmtProto.HostConfig{ - Uri: "stun:stun.netbird.io:3468", - Protocol: mgmtProto.HostConfig_UDP, - } - expectedTurnsConfig := &mgmtProto.ProtectedHostConfig{ - HostConfig: &mgmtProto.HostConfig{ - Uri: "turn:stun.netbird.io:3468", - Protocol: mgmtProto.HostConfig_UDP, - }, - User: "some_user", - Password: "some_password", - } - - Expect(decryptedResp.GetNetbirdConfig().Signal).To(BeEquivalentTo(expectedSignalConfig)) - Expect(decryptedResp.GetNetbirdConfig().Stuns).To(ConsistOf(expectedStunsConfig)) - Expect(decryptedResp.GetNetbirdConfig().Turns).To(ConsistOf(expectedTurnsConfig)) - }) - }) - }) - - Context("when there are 10 peers registered under one account", func() { - Context("when there are 10 more peers registered under the same account", func() { - Specify("all of the 10 peers will get updates of 10 newly registered peers", func() { - initialPeers := 10 - additionalPeers := 10 - - var peers []wgtypes.Key - for i := 0; i < initialPeers; i++ { - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - peers = append(peers, key) - } - - wg := sync2.WaitGroup{} - wg.Add(initialPeers + initialPeers*additionalPeers) - - var clients []mgmtProto.ManagementService_SyncClient - for _, peer := range peers { - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer) - Expect(err).NotTo(HaveOccurred()) - - // open stream - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: peer.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - clients = append(clients, sync) - - // receive stream - peer := peer - go func() { - for { - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - if err != nil { - break - } - decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, peer) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - Expect(err).NotTo(HaveOccurred()) - if len(resp.GetRemotePeers()) > 0 { - // only consider peer updates - wg.Done() - } - } - }() - } - - time.Sleep(1 * time.Second) - for i := 0; i < additionalPeers; i++ { - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - r := rand.New(rand.NewSource(time.Now().UnixNano())) - n := r.Intn(200) - time.Sleep(time.Duration(n) * time.Millisecond) - } - - wg.Wait() - - for _, syncClient := range clients { - err := syncClient.CloseSend() - Expect(err).NotTo(HaveOccurred()) - } - }) - }) - }) - - Context("when there are peers registered under one account concurrently", func() { - Specify("then there are no duplicate IPs", func() { - initialPeers := 30 - - ipChannel := make(chan string, 20) - for i := 0; i < initialPeers; i++ { - go func() { - defer GinkgoRecover() - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} - encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq) - Expect(err).NotTo(HaveOccurred()) - - // open stream - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = encryption.DecryptMessage(serverPubKey, key, encryptedResponse.Body, resp) - Expect(err).NotTo(HaveOccurred()) - - ipChannel <- resp.GetPeerConfig().Address - }() - } - - ips := make(map[string]struct{}) - for ip := range ipChannel { - if _, ok := ips[ip]; ok { - Fail("found duplicate IP: " + ip) - } - ips[ip] = struct{}{} - if len(ips) == initialPeers { - break - } - } - close(ipChannel) - }) - }) - - Context("after login two peers", func() { - Specify("then they receive the same network", func() { - key, _ := wgtypes.GenerateKey() - firstLogin := loginPeerWithValidSetupKey(serverPubKey, key, client) - key, _ = wgtypes.GenerateKey() - secondLogin := loginPeerWithValidSetupKey(serverPubKey, key, client) - - _, firstLoginNetwork, err := net.ParseCIDR(firstLogin.GetPeerConfig().GetAddress()) - Expect(err).NotTo(HaveOccurred()) - _, secondLoginNetwork, err := net.ParseCIDR(secondLogin.GetPeerConfig().GetAddress()) - Expect(err).NotTo(HaveOccurred()) - - Expect(secondLoginNetwork.String()).To(BeEquivalentTo(firstLoginNetwork.String())) - }) - }) -}) - -func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { - defer GinkgoRecover() - +type testSuite struct { + t *testing.T + addr string + grpcServer *grpc.Server + dataDir string + client mgmtProto.ManagementServiceClient + serverPubKey wgtypes.Key + conn *grpc.ClientConn +} + +func setupTest(t *testing.T) *testSuite { + t.Helper() + level, _ := log.ParseLevel("Debug") + log.SetLevel(level) + + ts := &testSuite{t: t} + + var err error + ts.dataDir, err = os.MkdirTemp("", "netbird_mgmt_test_tmp_*") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } + + config := &server.Config{} + _, err = util.ReadJson("testdata/management.json", config) + if err != nil { + t.Fatalf("failed to read management.json: %v", err) + } + config.Datadir = ts.dataDir + + var listener net.Listener + ts.grpcServer, listener = startServer(t, config, ts.dataDir, "testdata/store.sql") + ts.addr = listener.Addr().String() + + ts.client, ts.conn = createRawClient(t, ts.addr) + + resp, err := ts.client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) + if err != nil { + t.Fatalf("failed to get server key: %v", err) + } + + serverKey, err := wgtypes.ParseKey(resp.Key) + if err != nil { + t.Fatalf("failed to parse server key: %v", err) + } + ts.serverPubKey = serverKey + + return ts +} + +func tearDownTest(t *testing.T, ts *testSuite) { + t.Helper() + ts.grpcServer.Stop() + if err := ts.conn.Close(); err != nil { + t.Fatalf("failed to close client connection: %v", err) + } + time.Sleep(100 * time.Millisecond) + if err := os.RemoveAll(ts.dataDir); err != nil { + t.Fatalf("failed to remove data directory %s: %v", ts.dataDir, err) + } +} + +func loginPeerWithValidSetupKey( + t *testing.T, + serverPubKey wgtypes.Key, + key wgtypes.Key, + client mgmtProto.ManagementServiceClient, +) *mgmtProto.LoginResponse { + t.Helper() meta := &mgmtProto.PeerSystemMeta{ Hostname: key.PublicKey().String(), GoOS: runtime.GOOS, @@ -457,23 +110,30 @@ func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, clien Kernel: "kernel", NetbirdVersion: "", } - message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: ValidSetupKey, Meta: meta}) - Expect(err).NotTo(HaveOccurred()) + msgToEncrypt := &mgmtProto.LoginRequest{SetupKey: ValidSetupKey, Meta: meta} + message, err := encryption.EncryptMessage(serverPubKey, key, msgToEncrypt) + if err != nil { + t.Fatalf("failed to encrypt login request: %v", err) + } resp, err := client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ WgPubKey: key.PublicKey().String(), Body: message, }) - - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("login request failed: %v", err) + } loginResp := &mgmtProto.LoginResponse{} err = encryption.DecryptMessage(serverPubKey, key, resp.Body, loginResp) - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("failed to decrypt login response: %v", err) + } return loginResp } -func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) { +func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) { + t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -484,17 +144,27 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie Time: 10 * time.Second, Timeout: 2 * time.Second, })) - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("failed to dial gRPC server: %v", err) + } return mgmtProto.NewManagementServiceClient(conn), conn } -func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) { +func startServer( + t *testing.T, + config *server.Config, + dataDir string, + testFile string, +) (*grpc.Server, net.Listener) { + t.Helper() lis, err := net.Listen("tcp", ":0") - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("failed to listen on a random port: %v", err) + } s := grpc.NewServer() - store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) + str, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } @@ -504,23 +174,529 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) if err != nil { - log.Fatalf("failed creating metrics: %v", err) + t.Fatalf("failed creating metrics: %v", err) } - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics) + accountManager, err := server.BuildManager( + context.Background(), + str, + peersUpdateManager, + nil, + "", + "netbird.selfhosted", + eventStore, + nil, + false, + server.MocIntegratedValidator{}, + metrics, + ) if err != nil { - log.Fatalf("failed creating a manager: %v", err) + t.Fatalf("failed creating an account manager: %v", err) } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) - Expect(err).NotTo(HaveOccurred()) + mgmtServer, err := server.NewServer( + context.Background(), + config, + accountManager, + settings.NewManager(str), + peersUpdateManager, + secretsManager, + nil, + nil, + ) + if err != nil { + t.Fatalf("failed creating management server: %v", err) + } + mgmtProto.RegisterManagementServiceServer(s, mgmtServer) + go func() { if err := s.Serve(lis); err != nil { - Expect(err).NotTo(HaveOccurred()) + t.Errorf("failed to serve gRPC: %v", err) + return } }() return s, lis } + +func TestIsHealthy(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + healthy, err := ts.client.IsHealthy(context.TODO(), &mgmtProto.Empty{}) + if err != nil { + t.Fatalf("IsHealthy call returned an error: %v", err) + } + if healthy == nil { + t.Fatal("IsHealthy returned a nil response") + } +} + +func TestSyncNewPeerConfiguration(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedBytes, err := encryption.EncryptMessage(ts.serverPubKey, peerKey, syncReq) + if err != nil { + t.Fatalf("failed to encrypt sync request: %v", err) + } + + syncStream, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync: %v", err) + } + + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Fatalf("failed to receive sync response message: %v", err) + } + + resp := &mgmtProto.SyncResponse{} + err = encryption.DecryptMessage(ts.serverPubKey, peerKey, encryptedResponse.Body, resp) + if err != nil { + t.Fatalf("failed to decrypt sync response: %v", err) + } + + expectedSignalConfig := &mgmtProto.HostConfig{ + Uri: "signal.netbird.io:10000", + Protocol: mgmtProto.HostConfig_HTTP, + } + expectedStunsConfig := &mgmtProto.HostConfig{ + Uri: "stun:stun.netbird.io:3468", + Protocol: mgmtProto.HostConfig_UDP, + } + expectedTRUNHost := &mgmtProto.HostConfig{ + Uri: "turn:stun.netbird.io:3468", + Protocol: mgmtProto.HostConfig_UDP, + } + + assert.NotNil(t, resp.NetbirdConfig) + assert.Equal(t, resp.NetbirdConfig.Signal, expectedSignalConfig) + assert.Contains(t, resp.NetbirdConfig.Stuns, expectedStunsConfig) + assert.Equal(t, len(resp.NetbirdConfig.Turns), 1) + actualTURN := resp.NetbirdConfig.Turns[0] + assert.Greater(t, len(actualTURN.User), 0) + assert.Equal(t, actualTURN.HostConfig, expectedTRUNHost) + assert.Equal(t, len(resp.NetworkMap.OfflinePeers), 0) +} + +func TestSyncThreePeers(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + peerKey1, _ := wgtypes.GenerateKey() + peerKey2, _ := wgtypes.GenerateKey() + + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey1, ts.client) + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey2, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + syncBytes, err := pb.Marshal(syncReq) + if err != nil { + t.Fatalf("failed to marshal sync request: %v", err) + } + encryptedBytes, err := encryption.Encrypt(syncBytes, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to encrypt sync request: %v", err) + } + + syncStream, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync: %v", err) + } + + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Fatalf("failed to receive sync response: %v", err) + } + + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to decrypt sync response: %v", err) + } + + resp := &mgmtProto.SyncResponse{} + err = pb.Unmarshal(decryptedBytes, resp) + if err != nil { + t.Fatalf("failed to unmarshal sync response: %v", err) + } + + if len(resp.GetRemotePeers()) != 2 { + t.Fatalf("expected 2 remote peers, got %d", len(resp.GetRemotePeers())) + } + + var found1, found2 bool + for _, rp := range resp.GetRemotePeers() { + if rp.WgPubKey == peerKey1.PublicKey().String() { + found1 = true + } else if rp.WgPubKey == peerKey2.PublicKey().String() { + found2 = true + } + } + if !found1 || !found2 { + t.Fatalf("did not find the expected peer keys %s, %s among %v", + peerKey1.PublicKey().String(), + peerKey2.PublicKey().String(), + resp.GetRemotePeers()) + } +} + +func TestSyncNewPeerUpdate(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + syncBytes, err := pb.Marshal(syncReq) + if err != nil { + t.Fatalf("failed to marshal sync request: %v", err) + } + + encryptedBytes, err := encryption.Encrypt(syncBytes, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to encrypt sync request: %v", err) + } + + syncStream, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync: %v", err) + } + + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Fatalf("failed to receive first sync response: %v", err) + } + + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to decrypt first sync response: %v", err) + } + + resp := &mgmtProto.SyncResponse{} + if err := pb.Unmarshal(decryptedBytes, resp); err != nil { + t.Fatalf("failed to unmarshal first sync response: %v", err) + } + + if len(resp.GetRemotePeers()) != 0 { + t.Fatalf("expected 0 remote peers at first sync, got %d", len(resp.GetRemotePeers())) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Errorf("failed to receive second sync response: %v", err) + return + } + + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, ts.serverPubKey, peerKey) + if err != nil { + t.Errorf("failed to decrypt second sync response: %v", err) + return + } + err = pb.Unmarshal(decryptedBytes, resp) + if err != nil { + t.Errorf("failed to unmarshal second sync response: %v", err) + return + } + }() + + newPeerKey, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, newPeerKey, ts.client) + + wg.Wait() + + if len(resp.GetRemotePeers()) != 1 { + t.Fatalf("expected exactly 1 remote peer update, got %d", len(resp.GetRemotePeers())) + } + if resp.GetRemotePeers()[0].WgPubKey != newPeerKey.PublicKey().String() { + t.Fatalf("expected new peer key %s, got %s", + newPeerKey.PublicKey().String(), + resp.GetRemotePeers()[0].WgPubKey) + } +} + +func TestGetServerKey(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + resp, err := ts.client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) + if err != nil { + t.Fatalf("GetServerKey returned error: %v", err) + } + if resp == nil { + t.Fatal("GetServerKey returned nil response") + } + if resp.Key == "" { + t.Fatal("GetServerKey returned empty key") + } + if resp.ExpiresAt.AsTime().IsZero() { + t.Fatal("GetServerKey returned 0 for ExpiresAt") + } + + _, err = wgtypes.ParseKey(resp.Key) + if err != nil { + t.Fatalf("GetServerKey returned an invalid WG key: %v", err) + } +} + +func TestLoginInvalidSetupKey(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + request := &mgmtProto.LoginRequest{ + SetupKey: "invalid setup key", + Meta: &mgmtProto.PeerSystemMeta{}, + } + encryptedMsg, err := encryption.EncryptMessage(ts.serverPubKey, peerKey, request) + if err != nil { + t.Fatalf("failed to encrypt login request: %v", err) + } + + resp, err := ts.client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedMsg, + }) + if err == nil { + t.Fatal("expected error for invalid setup key but got nil") + } + if resp != nil { + t.Fatalf("expected nil response for invalid setup key but got: %+v", resp) + } +} + +func TestLoginValidSetupKey(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + resp := loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + if resp == nil { + t.Fatal("loginPeerWithValidSetupKey returned nil, expected a valid response") + } +} + +func TestLoginRegisteredPeer(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + regResp := loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + if regResp == nil { + t.Fatal("registration with valid setup key failed") + } + + loginReq := &mgmtProto.LoginRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedLogin, err := encryption.EncryptMessage(ts.serverPubKey, peerKey, loginReq) + if err != nil { + t.Fatalf("failed to encrypt login request: %v", err) + } + loginRespEnc, err := ts.client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedLogin, + }) + if err != nil { + t.Fatalf("login call returned an error: %v", err) + } + + loginResp := &mgmtProto.LoginResponse{} + err = encryption.DecryptMessage(ts.serverPubKey, peerKey, loginRespEnc.Body, loginResp) + if err != nil { + t.Fatalf("failed to decrypt login response: %v", err) + } + + expectedSignalConfig := &mgmtProto.HostConfig{ + Uri: "signal.netbird.io:10000", + Protocol: mgmtProto.HostConfig_HTTP, + } + expectedStunsConfig := &mgmtProto.HostConfig{ + Uri: "stun:stun.netbird.io:3468", + Protocol: mgmtProto.HostConfig_UDP, + } + expectedTurnsConfig := &mgmtProto.ProtectedHostConfig{ + HostConfig: &mgmtProto.HostConfig{ + Uri: "turn:stun.netbird.io:3468", + Protocol: mgmtProto.HostConfig_UDP, + }, + User: "some_user", + Password: "some_password", + } + + assert.NotNil(t, loginResp.GetNetbirdConfig()) + assert.Equal(t, loginResp.GetNetbirdConfig().Signal, expectedSignalConfig) + assert.Contains(t, loginResp.GetNetbirdConfig().Stuns, expectedStunsConfig) + assert.Contains(t, loginResp.GetNetbirdConfig().Turns, expectedTurnsConfig) +} + +func TestSync10PeersGetUpdates(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + initialPeers := 10 + additionalPeers := 10 + + var peers []wgtypes.Key + for i := 0; i < initialPeers; i++ { + key, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, key, ts.client) + peers = append(peers, key) + } + + var wg sync.WaitGroup + wg.Add(initialPeers + initialPeers*additionalPeers) + + var syncClients []mgmtProto.ManagementService_SyncClient + for _, pk := range peers { + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + msgBytes, err := pb.Marshal(syncReq) + if err != nil { + t.Fatalf("failed to marshal SyncRequest: %v", err) + } + encBytes, err := encryption.Encrypt(msgBytes, ts.serverPubKey, pk) + if err != nil { + t.Fatalf("failed to encrypt SyncRequest: %v", err) + } + + s, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: pk.PublicKey().String(), + Body: encBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync for peer: %v", err) + } + syncClients = append(syncClients, s) + + go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) { + for { + encMsg := &mgmtProto.EncryptedMessage{} + err := syncStream.RecvMsg(encMsg) + if err != nil { + return + } + decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk) + if decErr != nil { + t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr) + return + } + resp := &mgmtProto.SyncResponse{} + umErr := pb.Unmarshal(decryptedBytes, resp) + if umErr != nil { + t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr) + return + } + // We only count if there's a new peer update + if len(resp.GetRemotePeers()) > 0 { + wg.Done() + } + } + }(pk, s) + } + + time.Sleep(500 * time.Millisecond) + for i := 0; i < additionalPeers; i++ { + key, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, key, ts.client) + r := rand.New(rand.NewSource(time.Now().UnixNano())) + n := r.Intn(200) + time.Sleep(time.Duration(n) * time.Millisecond) + } + + wg.Wait() + + for _, sc := range syncClients { + err := sc.CloseSend() + if err != nil { + t.Fatalf("failed to close sync client: %v", err) + } + } +} + +func TestConcurrentPeersNoDuplicateIPs(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + initialPeers := 30 + ipChan := make(chan string, initialPeers) + + var wg sync.WaitGroup + wg.Add(initialPeers) + + for i := 0; i < initialPeers; i++ { + go func() { + defer wg.Done() + key, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, key, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedBytes, err := encryption.EncryptMessage(ts.serverPubKey, key, syncReq) + if err != nil { + t.Errorf("failed to encrypt sync request: %v", err) + return + } + + s, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: key.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Errorf("failed to call Sync: %v", err) + return + } + + encResp := &mgmtProto.EncryptedMessage{} + if err = s.RecvMsg(encResp); err != nil { + t.Errorf("failed to receive sync response: %v", err) + return + } + + resp := &mgmtProto.SyncResponse{} + if err = encryption.DecryptMessage(ts.serverPubKey, key, encResp.Body, resp); err != nil { + t.Errorf("failed to decrypt sync response: %v", err) + return + } + ipChan <- resp.GetPeerConfig().Address + }() + } + + wg.Wait() + close(ipChan) + + ipMap := make(map[string]bool) + for ip := range ipChan { + if ipMap[ip] { + t.Fatalf("found duplicate IP: %s", ip) + } + ipMap[ip] = true + } + + // Ensure we collected all peers + if len(ipMap) != initialPeers { + t.Fatalf("expected %d unique IPs, got %d", initialPeers, len(ipMap)) + } +} diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 0743db513..497d9af4f 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -379,12 +379,12 @@ func TestCreateNameServerGroup(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { am, err := createNSManager(t) if err != nil { - t.Error("failed to create account manager") + t.Fatalf("failed to create account manager: %s", err) } account, err := initTestNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } outNSGroup, err := am.CreateNameServerGroup( @@ -607,12 +607,12 @@ func TestSaveNameServerGroup(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { am, err := createNSManager(t) if err != nil { - t.Error("failed to create account manager") + t.Fatalf("failed to create account manager: %s", err) } account, err := initTestNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup @@ -706,7 +706,7 @@ func TestDeleteNameServerGroup(t *testing.T) { account, err := initTestNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } account.NameServerGroups[testingNSGroup.ID] = testingNSGroup @@ -741,7 +741,7 @@ func TestGetNameServerGroup(t *testing.T) { account, err := initTestNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Fatalf("failed to init testing account: %s", err) } foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) @@ -761,6 +761,7 @@ func TestGetNameServerGroup(t *testing.T) { func createNSManager(t *testing.T) (*DefaultAccountManager, error) { t.Helper() + store, err := createNSStore(t) if err != nil { return nil, err diff --git a/management/server/route_test.go b/management/server/route_test.go index 1c5c56f60..40e0f41b0 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -13,12 +13,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/management/server/activity" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - - "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 6e04c7d9d..dd240ce6c 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -37,40 +37,44 @@ import ( nbroute "github.com/netbirdio/netbird/route" ) -func TestSqlite_NewStore(t *testing.T) { +func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { + t.Helper() + for _, engine := range supportedEngines { + if os.Getenv("NETBIRD_STORE_ENGINE") != "" && os.Getenv("NETBIRD_STORE_ENGINE") != string(engine) { + continue + } + t.Setenv("NETBIRD_STORE_ENGINE", string(engine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), testDataFile, t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + t.Run(string(engine), func(t *testing.T) { + f(t, store) + }) + os.Unsetenv("NETBIRD_STORE_ENGINE") + } +} + +func Test_NewStore(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - if len(store.GetAllAccounts(context.Background())) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") - } + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Errorf("expected to create a new Store") + } + if len(store.GetAllAccounts(context.Background())) != 0 { + t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") + } + }) } -func TestSqlite_SaveAccount_Large(t *testing.T) { +func Test_SaveAccount_Large(t *testing.T) { if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { t.Skip("skip CI tests on darwin and windows") } - t.Run("SQLite", func(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - runLargeTest(t, store) - }) - - // create store outside to have a better time counter for the test - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - t.Run("PostgreSQL", func(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { runLargeTest(t, store) }) } @@ -215,77 +219,74 @@ func randomIPv4() net.IP { return net.IP(b) } -func TestSqlite_SaveAccount(t *testing.T) { +func Test_SaveAccount(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + account := newAccountWithId(context.Background(), "account_id", "testuser", "") + setupKey, _ := types.GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + account.Peers["testpeer"] = &nbpeer.Peer{ + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := types.GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err) - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") + setupKey, _ = types.GenerateDefaultSetupKey() + account2.SetupKeys[setupKey.Key] = setupKey + account2.Peers["testpeer2"] = &nbpeer.Peer{ + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = types.GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + err = store.SaveAccount(context.Background(), account2) + require.NoError(t, err) - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + if len(store.GetAllAccounts(context.Background())) != 2 { + t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") + } - if len(store.GetAllAccounts(context.Background())) != 2 { - t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") - } + a, err := store.GetAccount(context.Background(), account.Id) + if a == nil { + t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) + } - a, err := store.GetAccount(context.Background(), account.Id) - if a == nil { - t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) - } + if a != nil && len(a.Policies) != 1 { + t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) + } - if a != nil && len(a.Policies) != 1 { - t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) - } + if a != nil && len(a.Policies[0].Rules) != 1 { + t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) + return + } - if a != nil && len(a.Policies[0].Rules) != 1 { - t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) - return - } + if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { + t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) + } - if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { - t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) - } + if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { + t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) + } - if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { - t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) - } + if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { + t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) + } - if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { - t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) - } - - if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { - t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) - } + if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { + t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) + } + }) } func TestSqlite_DeleteAccount(t *testing.T) { @@ -402,27 +403,24 @@ func TestSqlite_DeleteAccount(t *testing.T) { } } -func TestSqlite_GetAccount(t *testing.T) { +func Test_GetAccount(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) + runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) { + id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + account, err := store.GetAccount(context.Background(), id) + require.NoError(t, err) + require.Equal(t, id, account.Id, "account id should match") - account, err := store.GetAccount(context.Background(), id) - require.NoError(t, err) - require.Equal(t, id, account.Id, "account id should match") - - _, err = store.GetAccount(context.Background(), "non-existing-account") - assert.Error(t, err) - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + _, err = store.GetAccount(context.Background(), "non-existing-account") + assert.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } func TestSqlStore_SavePeer(t *testing.T) { @@ -580,51 +578,45 @@ func TestSqlStore_SavePeerLocation(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } -func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { +func Test_TestGetAccountByPrivateDomain(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) + runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) { + existingDomain := "test.com" - existingDomain := "test.com" + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) + require.NoError(t, err, "should found account") + require.Equal(t, existingDomain, account.Domain, "domains should match") - account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) - require.NoError(t, err, "should found account") - require.Equal(t, existingDomain, account.Domain, "domains should match") - - _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") - require.Error(t, err, "should return error on domain lookup") - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") + require.Error(t, err, "should return error on domain lookup") + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } -func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { +func Test_GetTokenIDByHashedToken(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) + runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) { + hashed := "SoMeHaShEdToKeN" + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - hashed := "SoMeHaShEdToKeN" - id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) + require.NoError(t, err) + require.Equal(t, id, token) - token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) - require.NoError(t, err) - require.Equal(t, id, token) - - _, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") - require.Error(t, err) - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + _, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") + require.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } func TestMigrate(t *testing.T) { diff --git a/management/server/store/store.go b/management/server/store/store.go index 29ed22fa5..e074c4c60 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -9,11 +9,16 @@ import ( "os" "path" "path/filepath" + "regexp" "runtime" + "slices" "strings" "time" + "github.com/google/uuid" log "github.com/sirupsen/logrus" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -193,6 +198,8 @@ const ( mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN" ) +var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine} + func getStoreEngineFromEnv() Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") @@ -201,7 +208,7 @@ func getStoreEngineFromEnv() Engine { } value := Engine(strings.ToLower(kind)) - if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine { + if slices.Contains(supportedEngines, value) { return value } @@ -349,51 +356,126 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( } func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { - if kind == PostgresStoreEngine { - cleanUp, err := testutil.CreatePostgresTestContainer() - if err != nil { - return nil, nil, err + var cleanup func() + var err error + switch kind { + case PostgresStoreEngine: + store, cleanup, err = newReusedPostgresStore(ctx, store, kind) + case MysqlStoreEngine: + store, cleanup, err = newReusedMysqlStore(ctx, store, kind) + default: + cleanup = func() { + // sqlite doesn't need to be cleaned up } - - dsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) - } - - store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) - if err != nil { - return nil, nil, err - } - - return store, cleanUp, nil } - - if kind == MysqlStoreEngine { - cleanUp, err := testutil.CreateMysqlTestContainer() - if err != nil { - return nil, nil, err - } - - dsn, ok := os.LookupEnv(mysqlDsnEnv) - if !ok { - return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) - } - - store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil) - if err != nil { - return nil, nil, err - } - - return store, cleanUp, nil + if err != nil { + return nil, cleanup, fmt.Errorf("failed to create test store: %v", err) } closeConnection := func() { + cleanup() store.Close(ctx) } return store, closeConnection, nil } +func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) { + if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" { + var err error + _, err = testutil.CreatePostgresTestContainer() + if err != nil { + return nil, nil, err + } + } + + dsn, ok := os.LookupEnv(postgresDsnEnv) + if !ok { + return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) + } + + dsn, cleanup, err := createRandomDB(dsn, db, kind) + if err != nil { + return nil, cleanup, err + } + + store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) + if err != nil { + return nil, cleanup, err + } + + return store, cleanup, nil +} + +func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) { + if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" { + var err error + _, err = testutil.CreateMysqlTestContainer() + if err != nil { + return nil, nil, err + } + } + + dsn, ok := os.LookupEnv(mysqlDsnEnv) + if !ok { + return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) + } + + db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + if err != nil { + return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err) + } + + dsn, cleanup, err := createRandomDB(dsn, db, kind) + if err != nil { + return nil, cleanup, err + } + + store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil) + if err != nil { + return nil, nil, err + } + + return store, cleanup, nil +} + +func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) { + dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) + + if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil { + return "", nil, fmt.Errorf("failed to create database: %v", err) + } + + var err error + cleanup := func() { + switch engine { + case PostgresStoreEngine: + err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error + case MysqlStoreEngine: + // err = killMySQLConnections(dsn, dbName) + err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error + } + if err != nil { + log.Errorf("failed to drop database %s: %v", dbName, err) + panic(err) + } + sqlDB, _ := db.DB() + _ = sqlDB.Close() + } + + return replaceDBName(dsn, dbName), cleanup, nil +} + +func replaceDBName(dsn, newDBName string) string { + re := regexp.MustCompile(`(?P
[:/@])(?P[^/?]+)(?P \?|$)`) + return re.ReplaceAllString(dsn, `${pre}`+newDBName+`${post}`) +} + func loadSQL(db *gorm.DB, filepath string) error { sqlContent, err := os.ReadFile(filepath) if err != nil { diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 16438cab8..8672efa7f 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -22,7 +22,7 @@ func CreateMysqlTestContainer() (func(), error) { myContainer, err := mysql.RunContainer(ctx, testcontainers.WithImage("mlsmaycon/warmed-mysql:8"), mysql.WithDatabase("testing"), - mysql.WithUsername("testing"), + mysql.WithUsername("root"), mysql.WithPassword("testing"), testcontainers.WithWaitStrategy( wait.ForLog("/usr/sbin/mysqld: ready for connections"). @@ -34,6 +34,7 @@ func CreateMysqlTestContainer() (func(), error) { } cleanup := func() { + os.Unsetenv("NETBIRD_STORE_ENGINE_MYSQL_DSN") timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) defer cancelFunc() if err = myContainer.Terminate(timeoutCtx); err != nil { @@ -68,6 +69,7 @@ func CreatePostgresTestContainer() (func(), error) { } cleanup := func() { + os.Unsetenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN") timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) defer cancelFunc() if err = pgContainer.Terminate(timeoutCtx); err != nil { diff --git a/management/server/types/user.go b/management/server/types/user.go index 348fbfb22..5f7a4f2cb 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -80,7 +80,7 @@ type User struct { // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user AutoGroups []string `gorm:"serializer:json"` PATs map[string]*PersonalAccessToken `gorm:"-"` - PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` // Blocked indicates whether the user is blocked. Blocked users can't use the system. Blocked bool // LastLogin is the last time the user logged in to IdP diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index b007e24bb..cb525865b 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -11,8 +11,8 @@ import ( "net/url" "strings" - log "github.com/sirupsen/logrus" "github.com/coder/websocket" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/util/embeddedroots" diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go index 3466b2abd..3ec08945b 100644 --- a/relay/server/listener/ws/conn.go +++ b/relay/server/listener/ws/conn.go @@ -8,8 +8,8 @@ import ( "sync" "time" - log "github.com/sirupsen/logrus" "github.com/coder/websocket" + log "github.com/sirupsen/logrus" ) const ( diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 4597669dc..3a95951ee 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -8,8 +8,8 @@ import ( "net" "net/http" - log "github.com/sirupsen/logrus" "github.com/coder/websocket" + log "github.com/sirupsen/logrus" ) // URLPath is the path for the websocket connection.