diff --git a/pkg/common/telemetry/server/datastore/wrapper.go b/pkg/common/telemetry/server/datastore/wrapper.go index 14bc4244e3..e84dd6778c 100644 --- a/pkg/common/telemetry/server/datastore/wrapper.go +++ b/pkg/common/telemetry/server/datastore/wrapper.go @@ -131,6 +131,12 @@ func (w metricsWrapper) FetchAttestedNode(ctx context.Context, spiffeID string) return w.ds.FetchAttestedNode(ctx, spiffeID) } +func (w metricsWrapper) FetchAttestedNodes(ctx context.Context, spiffeIDs []string) (_ map[string]*common.AttestedNode, err error) { + callCounter := StartFetchNodeCall(w.m) + defer callCounter.Done(&err) + return w.ds.FetchAttestedNodes(ctx, spiffeIDs) +} + func (w metricsWrapper) FetchAttestedNodeEvent(ctx context.Context, eventID uint) (_ *datastore.AttestedNodeEvent, err error) { callCounter := StartFetchAttestedNodeEventCall(w.m) defer callCounter.Done(&err) diff --git a/pkg/common/telemetry/server/datastore/wrapper_test.go b/pkg/common/telemetry/server/datastore/wrapper_test.go index b5ac06d865..07479c0b2c 100644 --- a/pkg/common/telemetry/server/datastore/wrapper_test.go +++ b/pkg/common/telemetry/server/datastore/wrapper_test.go @@ -117,6 +117,10 @@ func TestWithMetrics(t *testing.T) { key: "datastore.node.fetch", methodName: "FetchAttestedNode", }, + { + key: "datastore.node.fetch", + methodName: "FetchAttestedNodes", + }, { key: "datastore.node_event.fetch", methodName: "FetchAttestedNodeEvent", @@ -424,6 +428,10 @@ func (ds *fakeDataStore) FetchAttestedNode(context.Context, string) (*common.Att return &common.AttestedNode{}, ds.err } +func (ds *fakeDataStore) FetchAttestedNodes(context.Context, []string) (map[string]*common.AttestedNode, error) { + return map[string]*common.AttestedNode{}, ds.err +} + func (ds *fakeDataStore) FetchAttestedNodeEvent(context.Context, uint) (*datastore.AttestedNodeEvent, error) { return &datastore.AttestedNodeEvent{}, ds.err } diff --git a/pkg/server/datastore/datastore.go b/pkg/server/datastore/datastore.go index 63884bdb61..d881b6e092 100644 --- a/pkg/server/datastore/datastore.go +++ b/pkg/server/datastore/datastore.go @@ -52,6 +52,9 @@ type DataStore interface { CreateAttestedNode(context.Context, *common.AttestedNode) (*common.AttestedNode, error) DeleteAttestedNode(ctx context.Context, spiffeID string) (*common.AttestedNode, error) FetchAttestedNode(ctx context.Context, spiffeID string) (*common.AttestedNode, error) + // FetchAttestedNodes fetches the given nodes (with selectors) keyed by SPIFFE ID. + // IDs with no existing node are omitted from the map and may be treated as deleted. + FetchAttestedNodes(ctx context.Context, spiffeIDs []string) (map[string]*common.AttestedNode, error) ListAttestedNodes(context.Context, *ListAttestedNodesRequest) (*ListAttestedNodesResponse, error) UpdateAttestedNode(context.Context, *common.AttestedNode, *common.AttestedNodeMask) (*common.AttestedNode, error) PruneAttestedExpiredNodes(ctx context.Context, expiredBefore time.Time, includeNonReattestable bool) error @@ -161,6 +164,7 @@ type ListAttestedNodesRequest struct { ByBanned *bool ByExpiresBefore time.Time BySelectorMatch *BySelectors + BySpiffeIDs []string FetchSelectors bool Pagination *Pagination ByCanReattest *bool diff --git a/pkg/server/datastore/sqlstore/sqlstore.go b/pkg/server/datastore/sqlstore/sqlstore.go index e7618f9750..107f516201 100644 --- a/pkg/server/datastore/sqlstore/sqlstore.go +++ b/pkg/server/datastore/sqlstore/sqlstore.go @@ -317,6 +317,27 @@ func (ds *Plugin) FetchAttestedNode(ctx context.Context, spiffeID string) (attes return attestedNode, nil } +// FetchAttestedNodes fetches existing attested nodes by SPIFFE IDs, including their selectors +func (ds *Plugin) FetchAttestedNodes(ctx context.Context, spiffeIDs []string) (map[string]*common.AttestedNode, error) { + nodesMap := make(map[string]*common.AttestedNode) + if len(spiffeIDs) == 0 { + return nodesMap, nil + } + + resp, err := listAttestedNodes(ctx, ds.db, ds.log, &datastore.ListAttestedNodesRequest{ + BySpiffeIDs: spiffeIDs, + FetchSelectors: true, + }) + if err != nil { + return nil, err + } + + for _, node := range resp.Nodes { + nodesMap[node.SpiffeId] = node + } + return nodesMap, nil +} + // CountAttestedNodes counts all attested nodes func (ds *Plugin) CountAttestedNodes(ctx context.Context, req *datastore.CountAttestedNodesRequest) (count int32, err error) { if countAttestedNodesHasFilters(req) { @@ -2048,6 +2069,14 @@ func buildListAttestedNodesQueryCTE(req *datastore.ListAttestedNodesRequest, dbT } } + // Filter by a set of SPIFFE IDs + if len(req.BySpiffeIDs) > 0 { + builder.WriteString("\t\tAND spiffe_id IN (") + builder.WriteString(buildQuestions(req.BySpiffeIDs)) + builder.WriteString(")\n") + args = append(args, buildArgs(req.BySpiffeIDs)...) + } + builder.WriteString(")") // Fetch all selectors from filtered entries if fetchSelectors { @@ -2281,6 +2310,14 @@ FROM attested_node_entries N builder.WriteString("\t\tAND can_reattest = false\n") } } + + // Filter by a set of SPIFFE IDs + if len(req.BySpiffeIDs) > 0 { + builder.WriteString(" AND N.spiffe_id IN (") + builder.WriteString(buildQuestions(req.BySpiffeIDs)) + builder.WriteString(")") + args = append(args, buildArgs(req.BySpiffeIDs)...) + } return nil } diff --git a/pkg/server/datastore/sqlstore/sqlstore_test.go b/pkg/server/datastore/sqlstore/sqlstore_test.go index f36c923f96..ccc48aa347 100644 --- a/pkg/server/datastore/sqlstore/sqlstore_test.go +++ b/pkg/server/datastore/sqlstore/sqlstore_test.go @@ -930,6 +930,80 @@ func (s *PluginSuite) TestFetchAttestedNodeMissing() { s.Require().Nil(attestedNode) } +func (s *PluginSuite) TestFetchAttestedNodes() { + createNode := func(spiffeID string, selectors []*common.Selector) *common.AttestedNode { + node, err := s.ds.CreateAttestedNode(ctx, &common.AttestedNode{ + SpiffeId: spiffeID, + AttestationDataType: "aws-tag", + CertSerialNumber: "badcafe", + CertNotAfter: time.Now().Add(time.Hour).Unix(), + }) + s.Require().NoError(err) + s.setNodeSelectors(spiffeID, selectors) + node.Selectors = selectors + return node + } + + node1 := createNode("spiffe://example.org/node1", []*common.Selector{{Type: "a", Value: "1"}}) + node2 := createNode("spiffe://example.org/node2", []*common.Selector{{Type: "b", Value: "2"}}) + node3 := createNode("spiffe://example.org/node3", []*common.Selector{{Type: "c", Value: "3"}}) + + // Create a node and then delete it so we can test it doesn't get returned with the fetch + node4 := createNode("spiffe://example.org/node4", []*common.Selector{{Type: "d", Value: "4"}}) + deletedNode, err := s.ds.DeleteAttestedNode(ctx, node4.SpiffeId) + s.Require().NoError(err) + s.Require().NotNil(deletedNode) + + for _, tt := range []struct { + name string + nodes []*common.AttestedNode + deletedSpiffeID string + }{ + { + name: "No nodes", + }, + { + name: "Nodes 1 and 2", + nodes: []*common.AttestedNode{node1, node2}, + }, + { + name: "Nodes 1, 2, and 3", + nodes: []*common.AttestedNode{node1, node2, node3}, + }, + { + name: "Deleted node", + nodes: []*common.AttestedNode{node2, node3}, + deletedSpiffeID: deletedNode.SpiffeId, + }, + } { + s.T().Run(tt.name, func(t *testing.T) { + spiffeIDs := make([]string, 0, len(tt.nodes)) + for _, node := range tt.nodes { + spiffeIDs = append(spiffeIDs, node.SpiffeId) + } + fetchedNodes, err := s.ds.FetchAttestedNodes(ctx, append(spiffeIDs, tt.deletedSpiffeID)) + s.Require().NoError(err) + + // Make sure all nodes we want to fetch are present, including selectors. + s.Require().Equal(len(tt.nodes), len(fetchedNodes)) + for _, node := range tt.nodes { + fetchedNode, ok := fetchedNodes[node.SpiffeId] + s.Require().True(ok) + s.RequireProtoEqual(node, fetchedNode) + } + + // Make sure any deleted nodes are not present. + _, ok := fetchedNodes[tt.deletedSpiffeID] + s.Require().False(ok) + }) + } + + // An empty request returns an empty map. + fetchedNodes, err := s.ds.FetchAttestedNodes(ctx, nil) + s.Require().NoError(err) + s.Require().Empty(fetchedNodes) +} + func (s *PluginSuite) TestListAttestedNodes() { // Connection is never used, each test creates a connection to a different database s.ds.Close() diff --git a/pkg/server/endpoints/authorized_entryfetcher.go b/pkg/server/endpoints/authorized_entryfetcher.go index 4d3a50c315..7005f34b83 100644 --- a/pkg/server/endpoints/authorized_entryfetcher.go +++ b/pkg/server/endpoints/authorized_entryfetcher.go @@ -171,7 +171,7 @@ func (a *AuthorizedEntryFetcherEvents) buildCache(ctx context.Context) error { return err } - attestedNodes, err := buildAttestedNodesCache(ctx, a.c.log, a.c.metrics, a.c.ds, a.c.clk, cache, a.c.nodeCache, a.c.cacheReloadInterval, a.c.eventTimeout) + attestedNodes, err := buildAttestedNodesCache(ctx, a.c.log, a.c.metrics, a.c.ds, a.c.clk, cache, a.c.nodeCache, pageSize, a.c.cacheReloadInterval, a.c.eventTimeout) if err != nil { return err } diff --git a/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go b/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go index 485bf718b4..037129bf82 100644 --- a/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go +++ b/pkg/server/endpoints/authorized_entryfetcher_attested_nodes.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "maps" + "slices" "time" "github.com/andres-erbsen/clock" @@ -35,6 +37,7 @@ type attestedNodes struct { eventTracker *eventTracker eventTimeout time.Duration + pageSize int32 fetchNodes map[string]struct{} @@ -158,7 +161,11 @@ func (a *attestedNodes) loadCache(ctx context.Context, cache *authorizedentries. // buildAttestedNodesCache fetches all attested nodes and adds the unexpired ones to the cache. // It runs once at startup. -func buildAttestedNodesCache(ctx context.Context, log logrus.FieldLogger, metrics telemetry.Metrics, ds datastore.DataStore, clk clock.Clock, cache *authorizedentries.Cache, nodeCache *nodecache.Cache, cacheReloadInterval, eventTimeout time.Duration) (*attestedNodes, error) { +func buildAttestedNodesCache(ctx context.Context, log logrus.FieldLogger, metrics telemetry.Metrics, ds datastore.DataStore, clk clock.Clock, cache *authorizedentries.Cache, nodeCache *nodecache.Cache, pageSize int32, cacheReloadInterval, eventTimeout time.Duration) (*attestedNodes, error) { + if pageSize <= 0 { + return nil, fmt.Errorf("page size must be positive, got %d", pageSize) + } + pollPeriods := PollPeriods(cacheReloadInterval, eventTimeout) attestedNodes := &attestedNodes{ @@ -169,6 +176,7 @@ func buildAttestedNodesCache(ctx context.Context, log logrus.FieldLogger, metric log: log, metrics: metrics, eventTimeout: eventTimeout, + pageSize: pageSize, eventsBeforeFirst: make(map[uint]struct{}), fetchNodes: make(map[string]struct{}), @@ -211,34 +219,39 @@ func (a *attestedNodes) updateCache(ctx context.Context) error { } func (a *attestedNodes) updateCachedNodes(ctx context.Context) error { - for spiffeId := range a.fetchNodes { - node, err := a.ds.FetchAttestedNode(ctx, spiffeId) + spiffeIds := slices.Collect(maps.Keys(a.fetchNodes)) + for pageStart := 0; pageStart < len(spiffeIds); pageStart += int(a.pageSize) { + fetchNodes := a.fetchNodesPage(spiffeIds, pageStart) + nodes, err := a.ds.FetchAttestedNodes(ctx, fetchNodes) if err != nil { - continue + return err } - // Node was deleted - if node == nil { - a.nodeCache.RemoveAttestedNode(spiffeId) - a.cache.RemoveAgent(spiffeId) - delete(a.fetchNodes, spiffeId) - continue - } + for _, spiffeId := range fetchNodes { + node, ok := nodes[spiffeId] + // Node was deleted (absent from the response, or explicitly nil) + if !ok || node == nil { + a.nodeCache.RemoveAttestedNode(spiffeId) + a.cache.RemoveAgent(spiffeId) + delete(a.fetchNodes, spiffeId) + continue + } - selectors, err := a.ds.GetNodeSelectors(ctx, spiffeId, datastore.RequireCurrent) - if err != nil { - continue + agentExpiresAt := time.Unix(node.CertNotAfter, 0) + a.cache.UpdateAgent(node.SpiffeId, agentExpiresAt, api.ProtoFromSelectors(node.Selectors)) + a.nodeCache.UpdateAttestedNode(node) + delete(a.fetchNodes, spiffeId) } - node.Selectors = selectors - - agentExpiresAt := time.Unix(node.CertNotAfter, 0) - a.cache.UpdateAgent(node.SpiffeId, agentExpiresAt, api.ProtoFromSelectors(node.Selectors)) - a.nodeCache.UpdateAttestedNode(node) - delete(a.fetchNodes, spiffeId) } return nil } +// fetchNodesPage gets the range for the page starting at pageStart +func (a *attestedNodes) fetchNodesPage(spiffeIds []string, pageStart int) []string { + pageEnd := min(len(spiffeIds), pageStart+int(a.pageSize)) + return spiffeIds[pageStart:pageEnd] +} + func (a *attestedNodes) swapCache(cache *authorizedentries.Cache) { a.cache = cache a.fetchNodes = make(map[string]struct{}) diff --git a/pkg/server/endpoints/authorized_entryfetcher_attested_nodes_test.go b/pkg/server/endpoints/authorized_entryfetcher_attested_nodes_test.go index ec1b34054a..2a05ef3877 100644 --- a/pkg/server/endpoints/authorized_entryfetcher_attested_nodes_test.go +++ b/pkg/server/endpoints/authorized_entryfetcher_attested_nodes_test.go @@ -78,6 +78,13 @@ func TestLoadNodeCache(t *testing.T) { }, expectedError: "any error, doesn't matter", }, + { + name: "loading with a non-positive page size raises an error", + setup: &nodeScenarioSetup{ + pageSize: -1, + }, + expectedError: "page size must be positive, got -1", + }, { name: "initial load loads nothing", }, @@ -1432,6 +1439,39 @@ func TestUpdateAttestedNodesCache(t *testing.T) { expectedAuthorizedEntries: []string{}, }, + { + name: "empty cache, fetch five nodes spanning multiple pages, three new and two deletes", + setup: &nodeScenarioSetup{ + pageSize: 2, + }, + createAttestedNodes: []*common.AttestedNode{ + { + SpiffeId: "spiffe://example.org/test_node_1", + CertNotAfter: time.Now().Add(time.Duration(240) * time.Hour).Unix(), + }, + { + SpiffeId: "spiffe://example.org/test_node_3", + CertNotAfter: time.Now().Add(time.Duration(240) * time.Hour).Unix(), + }, + { + SpiffeId: "spiffe://example.org/test_node_5", + CertNotAfter: time.Now().Add(time.Duration(240) * time.Hour).Unix(), + }, + }, + fetchNodes: []string{ + "spiffe://example.org/test_node_1", + "spiffe://example.org/test_node_2", + "spiffe://example.org/test_node_3", + "spiffe://example.org/test_node_4", + "spiffe://example.org/test_node_5", + }, + + expectedAuthorizedEntries: []string{ + "spiffe://example.org/test_node_1", + "spiffe://example.org/test_node_3", + "spiffe://example.org/test_node_5", + }, + }, } { t.Run(tt.name, func(t *testing.T) { scenario := NewNodeScenario(t, tt.setup) @@ -1472,19 +1512,21 @@ func TestUpdateAttestedNodesCache(t *testing.T) { // utility functions type scenario struct { - ctx context.Context - log *logrus.Logger - hook *test.Hook - clk *clock.Mock - cache *authorizedentries.Cache - metrics *fakemetrics.FakeMetrics - ds *fakedatastore.DataStore + ctx context.Context + log *logrus.Logger + hook *test.Hook + clk *clock.Mock + cache *authorizedentries.Cache + metrics *fakemetrics.FakeMetrics + ds *fakedatastore.DataStore + pageSize int32 } type nodeScenarioSetup struct { attestedNodes []*common.AttestedNode attestedNodeEvents []*datastore.AttestedNodeEvent err error + pageSize int32 } func NewNodeScenario(t *testing.T, setup *nodeScenarioSetup) *scenario { @@ -1500,6 +1542,10 @@ func NewNodeScenario(t *testing.T, setup *nodeScenarioSetup) *scenario { if setup == nil { setup = &nodeScenarioSetup{} } + pageSize := setup.pageSize + if pageSize == 0 { + pageSize = 1024 + } var err error // initialize the database @@ -1522,13 +1568,14 @@ func NewNodeScenario(t *testing.T, setup *nodeScenarioSetup) *scenario { } return &scenario{ - ctx: ctx, - log: log, - hook: hook, - clk: clk, - cache: cache, - metrics: metrics, - ds: ds, + ctx: ctx, + log: log, + hook: hook, + clk: clk, + cache: cache, + metrics: metrics, + ds: ds, + pageSize: pageSize, } } @@ -1538,7 +1585,7 @@ func (s *scenario) buildAttestedNodesCache() (*attestedNodes, error) { return nil, err } - attestedNodes, err := buildAttestedNodesCache(s.ctx, s.log, s.metrics, s.ds, s.clk, s.cache, nodeCache, defaultCacheReloadInterval, defaultEventTimeout) + attestedNodes, err := buildAttestedNodesCache(s.ctx, s.log, s.metrics, s.ds, s.clk, s.cache, nodeCache, s.pageSize, defaultCacheReloadInterval, defaultEventTimeout) if attestedNodes != nil { // clear out the fetches for node := range attestedNodes.fetchNodes { diff --git a/pkg/server/endpoints/authorized_entryfetcher_test.go b/pkg/server/endpoints/authorized_entryfetcher_test.go index 62da8fdde5..3286616d9a 100644 --- a/pkg/server/endpoints/authorized_entryfetcher_test.go +++ b/pkg/server/endpoints/authorized_entryfetcher_test.go @@ -215,7 +215,7 @@ func TestBuildCacheSavesSkippedEvents(t *testing.T) { require.NoError(t, err) require.NotNil(t, registrationEntries) - attestedNodes, err := buildAttestedNodesCache(ctx, log, metrics, ds, clk, cache, nodeCache, defaultCacheReloadInterval, defaultEventTimeout) + attestedNodes, err := buildAttestedNodesCache(ctx, log, metrics, ds, clk, cache, nodeCache, pageSize, defaultCacheReloadInterval, defaultEventTimeout) require.NoError(t, err) require.NotNil(t, attestedNodes) diff --git a/test/fakes/fakedatastore/fakedatastore.go b/test/fakes/fakedatastore/fakedatastore.go index 33a56a276d..e5fd0672d1 100644 --- a/test/fakes/fakedatastore/fakedatastore.go +++ b/test/fakes/fakedatastore/fakedatastore.go @@ -148,6 +148,13 @@ func (s *DataStore) FetchAttestedNode(ctx context.Context, spiffeID string) (*co return s.ds.FetchAttestedNode(ctx, spiffeID) } +func (s *DataStore) FetchAttestedNodes(ctx context.Context, spiffeIDs []string) (map[string]*common.AttestedNode, error) { + if err := s.getNextError(); err != nil { + return nil, err + } + return s.ds.FetchAttestedNodes(ctx, spiffeIDs) +} + func (s *DataStore) ListAttestedNodes(ctx context.Context, req *datastore.ListAttestedNodesRequest) (*datastore.ListAttestedNodesResponse, error) { if err := s.getNextError(); err != nil { return nil, err