diff --git a/go.mod b/go.mod index 2d994b30d..d4b702453 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/c-robinson/iplib v1.0.3 github.com/coreos/go-iptables v0.6.0 github.com/creack/pty v1.1.18 - github.com/eko/gocache/v2 v2.3.1 + github.com/eko/gocache/v3 v3.1.1 github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 @@ -41,7 +41,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/rs/xid v1.3.0 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 - github.com/stretchr/testify v1.7.1 + github.com/stretchr/testify v1.8.0 golang.org/x/net v0.0.0-20220630215102-69896b714898 golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 ) @@ -99,6 +99,7 @@ require ( github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/yuin/goldmark v1.4.1 // indirect + golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect @@ -112,7 +113,7 @@ require ( gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect honnef.co/go/tools v0.2.2 // indirect k8s.io/apimachinery v0.23.5 // indirect ) diff --git a/go.sum b/go.sum index 6bac16984..9acff5236 100644 --- a/go.sum +++ b/go.sum @@ -134,8 +134,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= -github.com/eko/gocache/v2 v2.3.1 h1:8MMkfqGJ0KIA9OXT0rXevcEIrU16oghrGDiIDJDFCa0= -github.com/eko/gocache/v2 v2.3.1/go.mod h1:l2z8OmpZHL0CpuzDJtxm267eF3mZW1NqUsMj+sKrbUs= +github.com/eko/gocache/v3 v3.1.1 h1:r3CBwLnqPkcK56h9Do2CWw1kZ4TeKK0wDE1Oo/YZnhs= +github.com/eko/gocache/v3 v3.1.1/go.mod h1:UpP/LyHAioP/a/dizgl0MpgZ3A3CkS4NbG/mWkGTQ9M= github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -609,6 +609,7 @@ github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9/go.mod h1:mvWM0+15 github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -616,8 +617,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= @@ -676,6 +678,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf h1:oXVg4h2qJDd9htKxb5SCpFBHLipW6hXmL3qpUixS2jw= +golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf/go.mod h1:yh0Ynu2b5ZUe3MQfp2nM0ecK7wsgouWTDN0FNeJuIys= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20200430140353-33d19683fad8 h1:6WW6V3x1P/jokJBpRQYUJnMHRP6isStQwCozxnU7XQw= @@ -1190,8 +1194,9 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/management/server/account.go b/management/server/account.go index 408898bcf..923da63b9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -3,8 +3,8 @@ package server import ( "context" "fmt" - "github.com/eko/gocache/v2/cache" - cacheStore "github.com/eko/gocache/v2/store" + "github.com/eko/gocache/v3/cache" + cacheStore "github.com/eko/gocache/v3/store" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -30,6 +30,11 @@ const ( CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days ) +func cacheEntryExpiration() time.Duration { + r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds()) + return time.Duration(r) * time.Millisecond +} + type AccountManager interface { GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetAccountByUser(userId string) (*Account, error) @@ -41,12 +46,13 @@ type AccountManager interface { autoGroups []string, ) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error) + CreateUser(accountID string, key *UserInfo) (*UserInfo, error) ListSetupKeys(accountID string) ([]*SetupKey, error) SaveUser(accountID string, key *User) (*UserInfo, error) GetSetupKey(accountID, keyID string) (*SetupKey, error) GetAccountById(accountId string) (*Account, error) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) - GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) + GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeer(peerKey string) (*Peer, error) @@ -90,11 +96,15 @@ type AccountManager interface { type DefaultAccountManager struct { Store Store - // mutex to synchronise account operations (e.g. generating Peer IP address inside the Network) - mux sync.Mutex + // mux to synchronise account operations (e.g. generating Peer IP address inside the Network) + mux sync.Mutex + // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID + cacheMux sync.Mutex + // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded + cacheLoading map[string]chan struct{} peersUpdateManager *PeersUpdateManager idpManager idp.Manager - cacheManager cache.CacheInterface + cacheManager cache.CacheInterface[[]*idp.UserData] ctx context.Context } @@ -122,6 +132,7 @@ type UserInfo struct { Name string `json:"name"` Role string `json:"role"` AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` } func (a *Account) Copy() *Account { @@ -193,6 +204,8 @@ func BuildManager( peersUpdateManager: peersUpdateManager, idpManager: idpManager, ctx: context.Background(), + cacheMux: sync.Mutex{}, + cacheLoading: map[string]chan struct{}{}, } // if account has not default group @@ -209,9 +222,9 @@ func BuildManager( } gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) - gocacheStore := cacheStore.NewGoCache(gocacheClient, nil) + gocacheStore := cacheStore.NewGoCache(gocacheClient) - am.cacheManager = cache.NewLoadable(am.loadFromCache, cache.New(gocacheStore)) + am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](gocacheStore)) if !isNil(am.idpManager) { go func() { @@ -256,11 +269,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error { } for accountID, users := range userData { - rand.Seed(time.Now().UnixNano()) - - r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds()) - expiration := time.Duration(r) * time.Millisecond - err = am.cacheManager.Set(am.ctx, accountID, users, &cacheStore.Options{Expiration: expiration}) + err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) if err != nil { return err } @@ -294,7 +303,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId( if err != nil { return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) } - err = am.updateIDPMetadata(userId, account.Id) + err = am.addAccountIDToIDPAppMeta(userId, account) if err != nil { return nil, err } @@ -308,10 +317,28 @@ func isNil(i idp.Manager) bool { return i == nil || reflect.ValueOf(i).IsNil() } -// updateIDPMetadata update user's app metadata in idp manager -func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) error { +// addAccountIDToIDPAppMeta update user's app metadata in idp manager +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error { if !isNil(am.idpManager) { - err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID}) + + // user can be nil if it wasn't found (e.g., just created) + user, err := am.lookupUserInCache(userID, account) + if err != nil { + return err + } + + if user != nil && user.AppMetadata.WTAccountID == account.Id { + // it was already set, so we skip the unnecessary update + log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", + account.Id, userID) + return nil + } + + err = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id}) + if err != nil { + return err + } + if err != nil { return status.Errorf( codes.Internal, @@ -319,45 +346,113 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err err, ) } + // refresh cache to reflect the update + _, err = am.refreshCache(account.Id) + if err != nil { + return err + } } return nil } -func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) { +func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) { + log.Debugf("account %s not found in cache, reloading", accountID) return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID)) } -func (am *DefaultAccountManager) lookupUserInCache(user *User, accountID string) (*idp.UserData, error) { - userData, err := am.lookupCache(map[string]*User{user.Id: user}, accountID) +func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) { + data, err := am.getAccountFromCache(accountID, false) if err != nil { return nil, err } - for _, datum := range userData { - if datum.ID == user.Id { + for _, datum := range data { + if datum.Email == email { return datum, nil } } - return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", user.Id) + return nil, nil + } -func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) { - data, err := am.cacheManager.Get(am.ctx, accountID) +// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil +func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { + users := make(map[string]struct{}, len(account.Users)) + for _, user := range account.Users { + users[user.Id] = struct{}{} + } + log.Debugf("looking up user %s of account %s in cache", userID, account.Id) + userData, err := am.lookupCache(users, account.Id) if err != nil { return nil, err } - userData := data.([]*idp.UserData) + for _, datum := range userData { + if datum.ID == userID { + return datum, nil + } + } + + return nil, nil +} + +func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) { + return am.getAccountFromCache(accountID, true) +} + +// getAccountFromCache returns user data for a given account ensuring that cache load happens only once +func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) { + am.cacheMux.Lock() + loadingChan := am.cacheLoading[accountID] + if loadingChan == nil { + loadingChan = make(chan struct{}) + am.cacheLoading[accountID] = loadingChan + am.cacheMux.Unlock() + + defer func() { + am.cacheMux.Lock() + delete(am.cacheLoading, accountID) + close(loadingChan) + am.cacheMux.Unlock() + }() + + if forceReload { + err := am.cacheManager.Delete(am.ctx, accountID) + if err != nil { + return nil, err + } + } + + return am.cacheManager.Get(am.ctx, accountID) + } + am.cacheMux.Unlock() + + log.Debugf("one request to get account %s is already running", accountID) + + select { + case <-loadingChan: + // channel has been closed meaning cache was loaded => simply return from cache + return am.cacheManager.Get(am.ctx, accountID) + case <-time.After(5 * time.Second): + return nil, fmt.Errorf("timeout while waiting for account %s cache to reload", accountID) + } +} + +func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) { + data, err := am.getAccountFromCache(accountID, false) + if err != nil { + return nil, err + } userDataMap := make(map[string]struct{}) - for _, datum := range userData { + for _, datum := range data { userDataMap[datum.ID] = struct{}{} } // check whether we need to reload the cache // the accountUsers ID list is the source of truth and all the users should be in the cache - reload := len(accountUsers) != len(userData) + reload := len(accountUsers) != len(data) for user := range accountUsers { if _, ok := userDataMap[user]; !ok { reload = true @@ -366,19 +461,13 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, acco if reload { // reload cache once avoiding loops - err := am.cacheManager.Delete(am.ctx, accountID) + data, err = am.refreshCache(accountID) if err != nil { return nil, err } - data, err = am.cacheManager.Get(am.ctx, accountID) - if err != nil { - return nil, err - } - - userData = data.([]*idp.UserData) } - return userData, err + return data, err } // updateAccountDomainAttributes updates the account domain attributes and then, saves the account @@ -433,7 +522,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( } // we should register the account ID to this user's metadata in our IDP manager - err = am.updateIDPMetadata(claims.UserId, existingAcc.Id) + err = am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc) if err != nil { return err } @@ -471,7 +560,7 @@ func (am *DefaultAccountManager) handleNewUserAccount( } } - err = am.updateIDPMetadata(claims.UserId, account.Id) + err = am.addAccountIDToIDPAppMeta(claims.UserId, account) if err != nil { return nil, err } @@ -479,7 +568,56 @@ func (am *DefaultAccountManager) handleNewUserAccount( return account, nil } -// GetAccountWithAuthorizationClaims retrievs an account using JWT Claims. +// redeemInvite checks whether user has been invited and redeems the invite +func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error { + // only possible with the enabled IdP manager + if am.idpManager == nil { + log.Warnf("invites only work with enabled IdP manager") + return nil + } + + user, err := am.lookupUserInCache(userID, account) + if err != nil { + return err + } + + if user == nil { + return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID) + } + + if user.AppMetadata.WTPendingInvite { + log.Infof("redeeming invite for user %s account %s", userID, account.Id) + // User has already logged in, meaning that IdP should have set wt_pending_invite to false. + // Our job is to just reload cache. + go func() { + _, err = am.refreshCache(account.Id) + if err != nil { + log.Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) + return + } + log.Debugf("user %s of account %s redeemed invite", user.ID, account.Id) + }() + } + + return nil +} + +// GetAccountFromToken returns an account associated with this token +func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) { + account, err := am.getAccountWithAuthorizationClaims(claims) + if err != nil { + return nil, err + } + + err = am.redeemInvite(account, claims.UserId) + if err != nil { + return nil, err + } + + return account, nil +} + +// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // @@ -496,7 +634,7 @@ func (am *DefaultAccountManager) handleNewUserAccount( // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims( +func (am *DefaultAccountManager) getAccountWithAuthorizationClaims( claims jwtclaims.AuthorizationClaims, ) (*Account, error) { // if Account ID is part of the claims diff --git a/management/server/account_test.go b/management/server/account_test.go index 77114e83d..c51eaf6a2 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -127,7 +127,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } } -func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { +func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims type test struct { @@ -310,7 +310,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, err := manager.GetAccountWithAuthorizationClaims(testCase.inputClaims) + account, err := manager.GetAccountFromToken(testCase.inputClaims) require.NoError(t, err, "support function failed") verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) diff --git a/management/server/error.go b/management/server/error.go new file mode 100644 index 000000000..1f7bce7f8 --- /dev/null +++ b/management/server/error.go @@ -0,0 +1,52 @@ +package server + +import ( + "fmt" +) + +const ( + // UserAlreadyExists indicates that user already exists + UserAlreadyExists ErrorType = 1 + // AccountNotFound indicates that specified account hasn't been found + AccountNotFound ErrorType = iota + // PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled + PreconditionFailed ErrorType = iota +) + +// ErrorType is a type of the Error +type ErrorType int32 + +// Error is an internal error +type Error struct { + errorType ErrorType + message string +} + +// Type returns the Type of the error +func (e *Error) Type() ErrorType { + return e.errorType +} + +// Error is an error string +func (e *Error) Error() string { + return e.message +} + +// Errorf returns Error(errorType, fmt.Sprintf(format, a...)). +func Errorf(errorType ErrorType, format string, a ...interface{}) error { + return &Error{ + errorType: errorType, + message: fmt.Sprintf(format, a...), + } +} + +// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise +func FromError(err error) (s *Error, ok bool) { + if err == nil { + return nil, true + } + if e, ok := err.(*Error); ok { + return e, true + } + return nil, false +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 8f0dc32e8..e0e4ea3f9 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -181,7 +181,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) } claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience) - _, err = s.accountManager.GetAccountWithAuthorizationClaims(claims) + _, err = s.accountManager.GetAccountFromToken(claims) if err != nil { return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 28a2419d2..88ac357dd 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -35,6 +35,10 @@ components: role: description: User's NetBird account role type: string + status: + description: User's status + type: string + enum: [ "active","invited","disabled" ] auto_groups: description: Groups to auto-assign to peers registered by this user type: array @@ -46,6 +50,7 @@ components: - name - role - auto_groups + - status UserRequest: type: object properties: @@ -60,6 +65,27 @@ components: required: - role - auto_groups + UserCreateRequest: + type: object + properties: + role: + description: User's NetBird account role + type: string + email: + description: User's Email to send invite to + type: string + name: + description: User's full name + type: string + auto_groups: + description: Groups to auto-assign to peers registered by this user + type: array + items: + type: string + required: + - role + - auto_groups + - email PeerMinimum: type: object properties: @@ -499,6 +525,33 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/: + post: + summary: Create a User (invite) + tags: [ Users] + security: + - BearerAuth: [ ] + requestBody: + description: User invite information + content: + 'application/json': + schema: + $ref: '#/components/schemas/UserCreateRequest' + responses: + '200': + description: A User object + content: + application/json: + schema: + $ref: '#/components/schemas/User' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/users/{id}: put: summary: Update information about a User diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 4e256666c..d82f1254c 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -87,6 +87,13 @@ const ( RulePatchOperationPathSources RulePatchOperationPath = "sources" ) +// Defines values for UserStatus. +const ( + UserStatusActive UserStatus = "active" + UserStatusDisabled UserStatus = "disabled" + UserStatusInvited UserStatus = "invited" +) + // Group defines model for Group. type Group struct { // Id Group ID @@ -466,6 +473,27 @@ type User struct { // Role User's NetBird account role Role string `json:"role"` + + // Status User's status + Status UserStatus `json:"status"` +} + +// UserStatus User's status +type UserStatus string + +// UserCreateRequest defines model for UserCreateRequest. +type UserCreateRequest struct { + // AutoGroups Groups to auto-assign to peers registered by this user + AutoGroups []string `json:"auto_groups"` + + // Email User's Email to send invite to + Email string `json:"email"` + + // Name User's full name + Name *string `json:"name,omitempty"` + + // Role User's NetBird account role + Role string `json:"role"` } // UserRequest defines model for UserRequest. @@ -586,5 +614,8 @@ type PostApiSetupKeysJSONRequestBody = SetupKeyRequest // PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType. type PutApiSetupKeysIdJSONRequestBody = SetupKeyRequest +// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType. +type PostApiUsersJSONRequestBody = UserCreateRequest + // PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType. type PutApiUsersIdJSONRequestBody = UserRequest diff --git a/management/server/http/groups_test.go b/management/server/http/groups_test.go index 115da9b66..aab18df49 100644 --- a/management/server/http/groups_test.go +++ b/management/server/http/groups_test.go @@ -67,14 +67,14 @@ func initGroupTestData(groups ...*server.Group) *Groups { } return nil, fmt.Errorf("peer not found") }, - GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", Peers: TestPeers, Groups: map[string]*server.Group{ - "id-existed": &server.Group{ID: "id-existed", Peers: []string{"A", "B"}}, - "id-all": &server.Group{ID: "id-all", Name: "All"}}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, + "id-all": {ID: "id-all", Name: "All"}}, }, nil }, }, diff --git a/management/server/http/handler.go b/management/server/http/handler.go index a4aeb8276..9e29a5dc3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -41,6 +41,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience Methods("GET", "PUT", "DELETE", "OPTIONS") apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") + apiHandler.HandleFunc("/api/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS") diff --git a/management/server/http/nameservers_test.go b/management/server/http/nameservers_test.go index 06438415d..c1c55a352 100644 --- a/management/server/http/nameservers_test.go +++ b/management/server/http/nameservers_test.go @@ -104,7 +104,7 @@ func initNameserversTestData() *Nameservers { } return nsGroupToUpdate, nil }, - GetAccountWithAuthorizationClaimsFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { return testingNSAccount, nil }, }, diff --git a/management/server/http/peers_test.go b/management/server/http/peers_test.go index 7f29dde99..45339f338 100644 --- a/management/server/http/peers_test.go +++ b/management/server/http/peers_test.go @@ -19,7 +19,7 @@ import ( func initTestMetaData(peer ...*server.Peer) *Peers { return &Peers{ accountManager: &mock_server.MockAccountManager{ - GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", diff --git a/management/server/http/routes_test.go b/management/server/http/routes_test.go index 459e1cc83..fe4f1bddf 100644 --- a/management/server/http/routes_test.go +++ b/management/server/http/routes_test.go @@ -120,7 +120,7 @@ func initRoutesTestData() *Routes { } return routeToUpdate, nil }, - GetAccountWithAuthorizationClaimsFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { return testingAccount, nil }, }, diff --git a/management/server/http/rules_test.go b/management/server/http/rules_test.go index 55033812a..e5c62324e 100644 --- a/management/server/http/rules_test.go +++ b/management/server/http/rules_test.go @@ -66,14 +66,14 @@ func initRulesTestData(rules ...*server.Rule) *Rules { } return &rule, nil }, - GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}}, Groups: map[string]*server.Group{ - "F": &server.Group{ID: "F"}, - "G": &server.Group{ID: "G"}, + "F": {ID: "F"}, + "G": {ID: "G"}, }, }, nil }, diff --git a/management/server/http/setupkeys_test.go b/management/server/http/setupkeys_test.go index c37702e47..239e196df 100644 --- a/management/server/http/setupkeys_test.go +++ b/management/server/http/setupkeys_test.go @@ -31,7 +31,7 @@ const ( func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys { return &SetupKeys{ accountManager: &mock_server.MockAccountManager{ - GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { return &server.Account{ Id: testAccountID, Domain: "hotmail.com", diff --git a/management/server/http/users.go b/management/server/http/users.go index 83b253824..3242c96e7 100644 --- a/management/server/http/users.go +++ b/management/server/http/users.go @@ -5,12 +5,11 @@ import ( "fmt" "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/http/api" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "net/http" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -82,6 +81,50 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } +// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite). +func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "", http.StatusNotFound) + } + + account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + if err != nil { + log.Error(err) + } + + req := &api.PostApiUsersJSONRequestBody{} + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { + http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest) + return + } + + newUser, err := h.accountManager.CreateUser(account.Id, &server.UserInfo{ + Email: req.Email, + Name: *req.Name, + Role: req.Role, + AutoGroups: req.AutoGroups, + }) + if err != nil { + if e, ok := server.FromError(err); ok { + switch e.Type() { + case server.UserAlreadyExists: + http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed) + return + default: + } + } + http.Error(w, "failed to invite", http.StatusInternalServerError) + return + } + writeJSONObject(w, toUserResponse(newUser)) +} + // GetUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { @@ -101,7 +144,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { return } - users := []*api.User{} + users := make([]*api.User, 0) for _, r := range data { users = append(users, toUserResponse(r)) } @@ -116,11 +159,22 @@ func toUserResponse(user *server.UserInfo) *api.User { autoGroups = []string{} } + var userStatus api.UserStatus + switch user.Status { + case "active": + userStatus = api.UserStatusActive + case "invited": + userStatus = api.UserStatusInvited + default: + userStatus = api.UserStatusDisabled + } + return &api.User{ Id: user.ID, Name: user.Name, Email: user.Email, Role: user.Role, AutoGroups: autoGroups, + Status: userStatus, } } diff --git a/management/server/http/users_test.go b/management/server/http/users_test.go index 6c4412848..597274c70 100644 --- a/management/server/http/users_test.go +++ b/management/server/http/users_test.go @@ -16,7 +16,7 @@ import ( func initUsers(user ...*server.User) *UserHandler { return &UserHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { users := make(map[string]*server.User, 0) for _, u := range user { users[u.Id] = u diff --git a/management/server/http/util.go b/management/server/http/util.go index 4e13b0e58..084ff3d9b 100644 --- a/management/server/http/util.go +++ b/management/server/http/util.go @@ -60,7 +60,7 @@ func getJWTAccount(accountManager server.AccountManager, jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience) - account, err := accountManager.GetAccountWithAuthorizationClaims(jwtClaims) + account, err := accountManager.GetAccountFromToken(jwtClaims) if err != nil { return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) } diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index bd73343fc..b7b8e012e 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -7,7 +7,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strconv" @@ -54,6 +53,16 @@ type Auth0Credentials struct { mux sync.Mutex } +// createUserRequest is a user create request +type createUserRequest struct { + Email string `json:"email"` + Name string `json:"name"` + AppMeta AppMetadata `json:"app_metadata"` + Connection string `json:"connection"` + Password string `json:"password"` + VerifyEmail bool `json:"verify_email"` +} + // userExportJobRequest is a user export request struct type userExportJobRequest struct { Format string `json:"format"` @@ -87,12 +96,13 @@ type userExportJobStatusResponse struct { // auth0Profile represents an Auth0 user profile response type auth0Profile struct { - AccountID string `json:"wt_account_id"` - UserID string `json:"user_id"` - Name string `json:"name"` - Email string `json:"email"` - CreatedAt string `json:"created_at"` - LastLogin string `json:"last_login"` + AccountID string `json:"wt_account_id"` + PendingInvite bool `json:"wt_pending_invite"` + UserID string `json:"user_id"` + Name string `json:"name"` + Email string `json:"email"` + CreatedAt string `json:"created_at"` + LastLogin string `json:"last_login"` } // NewAuth0Manager creates a new instance of the Auth0Manager @@ -172,7 +182,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { // parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) { jwtToken := JWTToken{} - body, err := ioutil.ReadAll(rawBody) + body, err := io.ReadAll(rawBody) if err != nil { return jwtToken, err } @@ -230,7 +240,7 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) { return c.jwtToken, nil } -func batchRequestUsersURL(authIssuer, accountID string, page int) (string, url.Values, error) { +func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) { u, err := url.Parse(authIssuer + "/api/v2/users") if err != nil { return "", nil, err @@ -238,6 +248,7 @@ func batchRequestUsersURL(authIssuer, accountID string, page int) (string, url.V q := u.Query() q.Set("page", strconv.Itoa(page)) q.Set("search_engine", "v3") + q.Set("per_page", strconv.Itoa(perPage)) q.Set("q", "app_metadata.wt_account_id:"+accountID) u.RawQuery = q.Encode() @@ -259,8 +270,9 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { // https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations // auth0 limitation of 1000 users via this endpoint + resultsPerPage := 50 for page := 0; page < 20; page++ { - reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page) + reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage) if err != nil { return nil, err } @@ -283,30 +295,31 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { return nil, err } + if res.StatusCode != 200 { + return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body)) + } + var batch []UserData err = json.Unmarshal(body, &batch) if err != nil { return nil, err } - log.Debugf("requested batch; %v", batch) + log.Debugf("returned user batch for accountID %s on page %d, %v", accountID, page, batch) err = res.Body.Close() if err != nil { return nil, err } - if res.StatusCode != 200 { - return nil, fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode) - } - - if len(batch) == 0 { - return list, nil - } - for user := range batch { list = append(list, &batch[user]) } + + if len(batch) == 0 || len(batch) < resultsPerPage { + log.Debugf("finished loading users for accountID %s", accountID) + return list, nil + } } return list, nil @@ -367,14 +380,12 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta reqURL := am.authIssuer + "/api/v2/users/" + userID - data, err := am.helper.Marshal(appMetadata) + data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata}) if err != nil { return err } - payloadString := fmt.Sprintf("{\"app_metadata\": %s}", string(data)) - - payload := strings.NewReader(payloadString) + payload := strings.NewReader(string(data)) req, err := http.NewRequest("PATCH", reqURL, payload) if err != nil { @@ -383,7 +394,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("content-type", "application/json") - log.Debugf("updating metadata for user %s", userID) + log.Debugf("updating IdP metadata for user %s", userID) res, err := am.httpClient.Do(req) if err != nil { @@ -404,6 +415,27 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta return nil } +func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) { + req := &createUserRequest{ + Email: email, + Name: name, + AppMeta: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: true, + }, + Connection: "Username-Password-Authentication", + Password: GeneratePassword(8, 1, 1, 1), + VerifyEmail: true, + } + + str, err := json.Marshal(req) + if err != nil { + return "", err + } + + return string(str), nil +} + func buildUserExportRequest() (string, error) { req := &userExportJobRequest{} fields := make([]map[string]string, 0) @@ -417,6 +449,11 @@ func buildUserExportRequest() (string, error) { "export_as": "wt_account_id", }) + fields = append(fields, map[string]string{ + "name": "app_metadata.wt_pending_invite", + "export_as": "wt_pending_invite", + }) + req.Format = "json" req.Fields = fields @@ -428,28 +465,39 @@ func buildUserExportRequest() (string, error) { return string(str), nil } -// GetAllAccounts gets all registered accounts with corresponding user data. -// It returns a list of users indexed by accountID. -func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { +func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { jwtToken, err := am.credentials.Authenticate() if err != nil { return nil, err } - reqURL := am.authIssuer + "/api/v2/jobs/users-exports" + reqURL := am.authIssuer + endpoint + payload := strings.NewReader(payloadStr) + + req, err := http.NewRequest("POST", reqURL, payload) + if err != nil { + return nil, err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + return req, nil + +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// It returns a list of users indexed by accountID. +func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { payloadString, err := buildUserExportRequest() if err != nil { return nil, err } - payload := strings.NewReader(payloadString) - exportJobReq, err := http.NewRequest("POST", reqURL, payload) + exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString) if err != nil { return nil, err } - exportJobReq.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) - exportJobReq.Header.Add("content-type", "application/json") jobResp, err := am.httpClient.Do(exportJobReq) if err != nil { @@ -469,7 +517,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { var exportJobResp userExportJobResponse - body, err := ioutil.ReadAll(jobResp.Body) + body, err := io.ReadAll(jobResp.Body) if err != nil { log.Debugf("Coudln't read export job response; %v", err) return nil, err @@ -500,6 +548,82 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { return nil, fmt.Errorf("failed extracting user profiles from auth0") } +// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list. +// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with +// the same email but different connections that are considered as separate accounts (e.g., Google and username/password). +func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { + jwtToken, err := am.credentials.Authenticate() + if err != nil { + return nil, err + } + reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + email + body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken) + if err != nil { + return nil, err + } + + userResp := []*UserData{} + + err = am.helper.Unmarshal(body, &userResp) + if err != nil { + log.Debugf("Coudln't unmarshal export job response; %v", err) + return nil, err + } + + return userResp, nil +} + +// CreateUser creates a new user in Auth0 Idp and sends an invite +func (am *Auth0Manager) CreateUser(email string, name string, accountID string) (*UserData, error) { + + payloadString, err := buildCreateUserRequestPayload(email, name, accountID) + if err != nil { + return nil, err + } + req, err := am.createPostRequest("/api/v2/users", payloadString) + if err != nil { + return nil, err + } + + resp, err := am.httpClient.Do(req) + if err != nil { + log.Debugf("Couldn't get job response %v", err) + return nil, err + } + + defer func() { + err = resp.Body.Close() + if err != nil { + log.Errorf("error while closing create user response body: %v", err) + } + }() + if !(resp.StatusCode == 200 || resp.StatusCode == 201) { + return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) + } + + var createResp UserData + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Debugf("Coudln't read export job response; %v", err) + return nil, err + } + + err = am.helper.Unmarshal(body, &createResp) + if err != nil { + log.Debugf("Coudln't unmarshal export job response; %v", err) + return nil, err + } + + if createResp.ID == "" { + return nil, fmt.Errorf("couldn't create user: response %v", resp) + } + + log.Debugf("created user %s in account %s", createResp.ID, accountID) + + return &createResp, nil +} + // checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // If the status is "completed", then return the downloadLink func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { @@ -572,6 +696,10 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us ID: profile.UserID, Name: profile.Name, Email: profile.Email, + AppMetadata: AppMetadata{ + WTAccountID: profile.AccountID, + WTPendingInvite: profile.PendingInvite, + }, }) } } @@ -605,7 +733,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return nil, err } diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index bded9bf5b..d1eb9fab7 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "github.com/stretchr/testify/require" - "io/ioutil" + "io" "net/http" "strings" "testing" @@ -22,13 +22,13 @@ type mockHTTPClient struct { } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - body, err := ioutil.ReadAll(req.Body) + body, err := io.ReadAll(req.Body) if err == nil { c.reqBody = string(body) } return &http.Response{ StatusCode: c.code, - Body: ioutil.NopCloser(strings.NewReader(c.resBody)), + Body: io.NopCloser(strings.NewReader(c.resBody)), }, c.err } @@ -130,7 +130,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) { t.Fatal(err) } } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) assert.NoError(t, err, "unable to read the response body") jwtToken := JWTToken{} @@ -178,7 +178,7 @@ func TestAuth0_ParseRequestJWTResponse(t *testing.T) { for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} { t.Run(testCase.name, func(t *testing.T) { - rawBody := ioutil.NopCloser(strings.NewReader(testCase.inputResBody)) + rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody)) config := Auth0ClientConfig{} @@ -320,7 +320,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { exp := 15 token := newTestJWT(t, exp) - appMetadata := AppMetadata{WTAccountId: "ok"} + appMetadata := AppMetadata{WTAccountID: "ok"} updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{ name: "Bad Authentication", @@ -340,7 +340,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ name: "Bad Status Code", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID), appMetadata: appMetadata, statusCode: 400, helper: JsonParser{}, @@ -363,7 +363,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ name: "Good request", inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), - expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId), + expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID), appMetadata: appMetadata, statusCode: 200, helper: JsonParser{}, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index f6bd94224..f43540b31 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -13,6 +13,8 @@ type Manager interface { GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) GetAccount(accountId string) ([]*UserData, error) GetAllAccounts() (map[string][]*UserData, error) + CreateUser(email string, name string, accountID string) (*UserData, error) + GetUserByEmail(email string) ([]*UserData, error) } // Config an idp configuration struct to be loaded from management server's config file @@ -38,16 +40,18 @@ type ManagerHelper interface { } type UserData struct { - Email string `json:"email"` - Name string `json:"name"` - ID string `json:"user_id"` + Email string `json:"email"` + Name string `json:"name"` + ID string `json:"user_id"` + AppMetadata AppMetadata `json:"app_metadata"` } // AppMetadata user app metadata to associate with a profile type AppMetadata struct { - // Wiretrustee account id to update in the IDP + // WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP // maps to wt_account_id when json.marshal - WTAccountId string `json:"wt_account_id"` + WTAccountID string `json:"wt_account_id,omitempty"` + WTPendingInvite bool `json:"wt_pending_invite"` } // JWTToken a JWT object that holds information of a token diff --git a/management/server/idp/util.go b/management/server/idp/util.go index 3d963dfc5..2401a7207 100644 --- a/management/server/idp/util.go +++ b/management/server/idp/util.go @@ -1,6 +1,18 @@ package idp -import "encoding/json" +import ( + "encoding/json" + "math/rand" + "strings" +) + +var ( + lowerCharSet = "abcdedfghijklmnopqrst" + upperCharSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + specialCharSet = "!@#$%&*" + numberSet = "0123456789" + allCharSet = lowerCharSet + upperCharSet + specialCharSet + numberSet +) type JsonParser struct{} @@ -11,3 +23,37 @@ func (JsonParser) Marshal(v interface{}) ([]byte, error) { func (JsonParser) Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) } + +// GeneratePassword generates user password +func GeneratePassword(passwordLength, minSpecialChar, minNum, minUpperCase int) string { + var password strings.Builder + + //Set special character + for i := 0; i < minSpecialChar; i++ { + random := rand.Intn(len(specialCharSet)) + password.WriteString(string(specialCharSet[random])) + } + + //Set numeric + for i := 0; i < minNum; i++ { + random := rand.Intn(len(numberSet)) + password.WriteString(string(numberSet[random])) + } + + //Set uppercase + for i := 0; i < minUpperCase; i++ { + random := rand.Intn(len(upperCharSet)) + password.WriteString(string(upperCharSet[random])) + } + + remainingLength := passwordLength - minSpecialChar - minNum - minUpperCase + for i := 0; i < remainingLength; i++ { + random := rand.Intn(len(allCharSet)) + password.WriteString(string(allCharSet[random])) + } + inRune := []rune(password.String()) + rand.Shuffle(len(inRune), func(i, j int) { + inRune[i], inRune[j] = inRune[j], inRune[i] + }) + return string(inRune) +} diff --git a/management/server/management_test.go b/management/server/management_test.go index b66b7a827..954a039da 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -2,7 +2,6 @@ package server_test import ( "context" - "io/ioutil" "math/rand" "net" "os" @@ -45,7 +44,7 @@ var _ = Describe("Management service", func() { level, _ := log.ParseLevel("Debug") log.SetLevel(level) var err error - dataDir, err = ioutil.TempDir("", "wiretrustee_mgmt_test_tmp_*") + dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*") Expect(err).NotTo(HaveOccurred()) err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json")) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 099e696a9..4a6099726 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -11,55 +11,56 @@ import ( ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) - GetAccountByUserFunc func(userId string) (*server.Account, error) - CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) - GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error) - GetAccountByIdFunc func(accountId string) (*server.Account, error) - GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) - IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) - AccountExistsFunc func(accountId string) (*bool, error) - GetPeerFunc func(peerKey string) (*server.Peer, error) - MarkPeerConnectedFunc func(peerKey string, connected bool) error - RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) - DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) - GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) - GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(peerKey string) (*server.Network, error) - AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error) - GetGroupFunc func(accountID, groupID string) (*server.Group, error) - SaveGroupFunc func(accountID string, group *server.Group) error - UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) - DeleteGroupFunc func(accountID, groupID string) error - ListGroupsFunc func(accountID string) ([]*server.Group, error) - GroupAddPeerFunc func(accountID, groupID, peerKey string) error - GroupDeletePeerFunc func(accountID, groupID, peerKey string) error - GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) - GetRuleFunc func(accountID, ruleID string) (*server.Rule, error) - SaveRuleFunc func(accountID string, rule *server.Rule) error - UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) - DeleteRuleFunc func(accountID, ruleID string) error - ListRulesFunc func(accountID string) ([]*server.Rule, error) - GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) - UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error - UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error) - CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) - GetRouteFunc func(accountID, routeID string) (*route.Route, error) - SaveRouteFunc func(accountID string, route *route.Route) error - UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) - DeleteRouteFunc func(accountID, routeID string) error - ListRoutesFunc func(accountID string) ([]*route.Route, error) - SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error) - ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error) - SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error) - GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error - UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) - DeleteNameServerGroupFunc func(accountID, nsGroupID string) error - ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) + GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) + GetAccountByUserFunc func(userId string) (*server.Account, error) + CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) + GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error) + GetAccountByIdFunc func(accountId string) (*server.Account, error) + GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) + IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) + AccountExistsFunc func(accountId string) (*bool, error) + GetPeerFunc func(peerKey string) (*server.Peer, error) + MarkPeerConnectedFunc func(peerKey string, connected bool) error + RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) + DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) + GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) + GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) + GetPeerNetworkFunc func(peerKey string) (*server.Network, error) + AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error) + GetGroupFunc func(accountID, groupID string) (*server.Group, error) + SaveGroupFunc func(accountID string, group *server.Group) error + UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) + DeleteGroupFunc func(accountID, groupID string) error + ListGroupsFunc func(accountID string) ([]*server.Group, error) + GroupAddPeerFunc func(accountID, groupID, peerKey string) error + GroupDeletePeerFunc func(accountID, groupID, peerKey string) error + GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) + GetRuleFunc func(accountID, ruleID string) (*server.Rule, error) + SaveRuleFunc func(accountID string, rule *server.Rule) error + UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) + DeleteRuleFunc func(accountID, ruleID string) error + ListRulesFunc func(accountID string) ([]*server.Rule, error) + GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) + UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error + UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error + UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error) + CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) + GetRouteFunc func(accountID, routeID string) (*route.Route, error) + SaveRouteFunc func(accountID string, route *route.Route) error + UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) + DeleteRouteFunc func(accountID, routeID string) error + ListRoutesFunc func(accountID string) ([]*route.Route, error) + SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error) + ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error) + SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error) + GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error + UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) + DeleteNameServerGroupFunc func(accountID, nsGroupID string) error + ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error) + GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -126,19 +127,6 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId( ) } -// GetAccountWithAuthorizationClaims mock implementation of GetAccountWithAuthorizationClaims from server.AccountManager interface -func (am *MockAccountManager) GetAccountWithAuthorizationClaims( - claims jwtclaims.AuthorizationClaims, -) (*server.Account, error) { - if am.GetAccountWithAuthorizationClaimsFunc != nil { - return am.GetAccountWithAuthorizationClaimsFunc(claims) - } - return nil, status.Errorf( - codes.Unimplemented, - "method GetAccountWithAuthorizationClaims is not implemented", - ) -} - // AccountExists mock implementation of AccountExists from server.AccountManager interface func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { if am.AccountExistsFunc != nil { @@ -485,3 +473,19 @@ func (am *MockAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.N } return nil, nil } + +// CreateUser mocks CreateUser of the AccountManager interface +func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserInfo) (*server.UserInfo, error) { + if am.CreateUserFunc != nil { + return am.CreateUserFunc(accountID, invite) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") +} + +// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface +func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + if am.GetAccountFromTokenFunc != nil { + return am.GetAccountFromTokenFunc(claims) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") +} diff --git a/management/server/user.go b/management/server/user.go index 391debfb8..d256238f9 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -14,6 +14,10 @@ const ( UserRoleAdmin UserRole = "admin" UserRoleUser UserRole = "user" UserRoleUnknown UserRole = "unknown" + + UserStatusActive UserStatus = "active" + UserStatusDisabled UserStatus = "disabled" + UserStatusInvited UserStatus = "invited" ) // StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown @@ -28,7 +32,10 @@ func StrRoleToUserRole(strRole string) UserRole { } } -// UserRole is the role of the User +// UserStatus is the status of a User +type UserStatus string + +// UserRole is the role of a User type UserRole string // User represents a user of the system @@ -53,24 +60,31 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { Name: "", Role: string(u.Role), AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), }, nil } if userData.ID != u.Id { return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) } + userStatus := UserStatusActive + if userData.AppMetadata.WTPendingInvite { + userStatus = UserStatusInvited + } + return &UserInfo{ ID: u.Id, Email: userData.Email, Name: userData.Name, Role: string(u.Role), AutoGroups: autoGroups, + Status: string(userStatus), }, nil } // Copy the user func (u *User) Copy() *User { - autoGroups := []string{} + autoGroups := make([]string, 0) autoGroups = append(autoGroups, u.AutoGroups...) return &User{ Id: u.Id, @@ -98,6 +112,70 @@ func NewAdminUser(id string) *User { return NewUser(id, UserRoleAdmin) } +// CreateUser creates a new user under the given account. Effectively this is a user invite. +func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) { + am.mux.Lock() + defer am.mux.Unlock() + + if am.idpManager == nil { + return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites") + } + + if invite == nil { + return nil, fmt.Errorf("provided user update is nil") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID) + } + + // check if the user is already registered with this email => reject + user, err := am.lookupUserInCacheByEmail(invite.Email, accountID) + if err != nil { + return nil, err + } + + if user != nil { + return nil, Errorf(UserAlreadyExists, "user has an existing account") + } + + users, err := am.idpManager.GetUserByEmail(invite.Email) + if err != nil { + return nil, err + } + + if len(users) > 0 { + return nil, Errorf(UserAlreadyExists, "user has an existing account") + } + + idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID) + if err != nil { + return nil, err + } + + role := StrRoleToUserRole(invite.Role) + newUser := &User{ + Id: idpUser.ID, + Role: role, + AutoGroups: invite.AutoGroups, + } + account.Users[idpUser.ID] = newUser + + err = am.Store.SaveAccount(account) + if err != nil { + return nil, err + } + + _, err = am.refreshCache(account.Id) + if err != nil { + return nil, err + } + + return newUser.toUserInfo(idpUser) + +} + // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. // Only User.AutoGroups field is allowed to be updated for now. func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) { @@ -138,10 +216,13 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User } if !isNil(am.idpManager) { - userData, err := am.lookupUserInCache(newUser, accountID) + userData, err := am.lookupUserInCache(newUser.Id, account) if err != nil { return nil, err } + if userData == nil { + return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id) + } return newUser.toUserInfo(userData) } return newUser.toUserInfo(nil) @@ -194,7 +275,7 @@ func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, erro // IsUserAdmin flag for current user authenticated by JWT token func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { - account, err := am.GetAccountWithAuthorizationClaims(claims) + account, err := am.GetAccountFromToken(claims) if err != nil { return false, fmt.Errorf("get account: %v", err) } @@ -216,7 +297,11 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI queriedUsers := make([]*idp.UserData, 0) if !isNil(am.idpManager) { - queriedUsers, err = am.lookupCache(account.Users, accountID) + users := make(map[string]struct{}, len(account.Users)) + for _, user := range account.Users { + users[user.Id] = struct{}{} + } + queriedUsers, err = am.lookupCache(users, accountID) if err != nil { return nil, err }