From bc249b3b28926774c53bd0cd9c74d8cc7f61165f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Mart=C3=ADnez=20Fay=C3=B3?= Date: Wed, 22 Apr 2026 13:07:20 -0300 Subject: [PATCH 1/2] Add tag-based key discovery support in the `aws_kms` KeyManager plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Martínez Fayó --- doc/plugin_server_keymanager_aws_kms.md | 85 ++- go.mod | 7 +- go.sum | 14 +- pkg/server/plugin/keymanager/awskms/awskms.go | 412 +++++++++++- .../plugin/keymanager/awskms/awskms_test.go | 632 +++++++++++++++++- pkg/server/plugin/keymanager/awskms/client.go | 10 + .../plugin/keymanager/awskms/client_fake.go | 63 +- .../plugin/keymanager/awskms/fetcher.go | 252 ++++++- 8 files changed, 1408 insertions(+), 67 deletions(-) diff --git a/doc/plugin_server_keymanager_aws_kms.md b/doc/plugin_server_keymanager_aws_kms.md index 2d1679c3ae..569f118eee 100644 --- a/doc/plugin_server_keymanager_aws_kms.md +++ b/doc/plugin_server_keymanager_aws_kms.md @@ -6,16 +6,17 @@ The `aws_kms` key manager plugin leverages the AWS Key Management Service (KMS) The plugin accepts the following configuration options: -| Key | Type | Required | Description | Default | -|----------------------|--------|---------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------| -| access_key_id | string | see [AWS KMS Access](#aws-kms-access) | The Access Key Id used to authenticate to KMS | Value of the AWS_ACCESS_KEY_ID environment variable | -| secret_access_key | string | see [AWS KMS Access](#aws-kms-access) | The Secret Access Key used to authenticate to KMS | Value of the AWS_SECRET_ACCESS_KEY environment variable | -| region | string | yes | The region where the keys will be stored | | -| key_identifier_file | string | Required if key_identifier_value is not set | A file path location where information about generated keys will be persisted | | -| key_identifier_value | string | Required if key_identifier_file is not set | A static identifier for the SPIRE server instance (used instead of `key_identifier_file`) | | -| key_policy_file | string | no | A file path location to a custom key policy in JSON format | "" | - -### Alias and Key Management +| Key | Type | Required | Description | Default | +|----------------------------------|---------|---------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------| +| access_key_id | string | see [AWS KMS Access](#aws-kms-access) | The Access Key Id used to authenticate to KMS | Value of the AWS_ACCESS_KEY_ID environment variable | +| secret_access_key | string | see [AWS KMS Access](#aws-kms-access) | The Secret Access Key used to authenticate to KMS | Value of the AWS_SECRET_ACCESS_KEY environment variable | +| region | string | yes | The region where the keys will be stored | | +| key_identifier_file | string | Required if key_identifier_value is not set | A file path location where information about generated keys will be persisted | | +| key_identifier_value | string | Required if key_identifier_file is not set | A static identifier for the SPIRE server instance (used instead of `key_identifier_file`) | | +| key_policy_file | string | no | A file path location to a custom key policy in JSON format | "" | +| enable_tag_based_key_discovery | boolean | no | Enable tag-based key discovery (recommended). See [Tag-based Key Discovery](#tag-based-key-discovery). | false | + +### Server Instance Identification The plugin needs a way to identify the specific server instance where it's running. For that, either the `key_identifier_file` or `key_identifier_value` @@ -32,9 +33,40 @@ If you need more control over the identifier that's used for the server, the static identifier for the server instance. This setting is appropriate in situations where a key identifier file can't be persisted. -The plugin assigns [aliases](https://docs.aws.amazon.com/kms/latest/developerguide/kms-alias.html) to the Customer Master Keys that it manages. The aliases are used to identify and name keys that are managed by the plugin. +### Tag-based Key Discovery + +> **Recommended.** Enable with `enable_tag_based_key_discovery = true`. + +When tag-based key discovery is enabled, the plugin uses the [AWS Resource Groups Tagging API](https://docs.aws.amazon.com/resourcegroupstagging/latest/APIReference/overview.html) to efficiently find only the KMS keys managed by this plugin instance. Keys are identified using the following SPIRE-specific tags applied at creation time: + +| Tag key | Description | +|---------------------|--------------------------------------------------------------------------------| +| `spire-server-td` | The trust domain of the SPIRE server | +| `spire-server-id` | The server instance identifier | +| `spire-active` | Set to `true` for actively managed keys | +| `spire-key-id` | The SPIRE key identifier | +| `spire-last-update` | Unix timestamp of the last update (set at creation and refreshed periodically) | + +The plugin stamps `spire-last-update` when a key is created (and when migrating an existing key) and periodically refreshes it on all active keys. Any key whose `spire-last-update` timestamp is older than two weeks is considered stale, marked inactive (`spire-active=false`), and its associated KMS key is scheduled for deletion. + +**Migration from alias-based discovery:** When enabling tag-based discovery on a server that previously used alias-based discovery, the plugin automatically detects existing untagged keys during startup and applies SPIRE tags to them. No manual migration steps are required. + +This mode requires the `tag:GetResources` permission (from the Resource Groups Tagging API). See [AWS KMS Access](#aws-kms-access) for the full permission list. + +#### Configuration consistency in HA deployments + +The `enable_tag_based_key_discovery` setting should be configured consistently across all SPIRE servers in the same trust domain. As with alias-based discovery, each server periodically refreshes a liveness signal on the keys it manages, and any server will reclaim keys in its trust domain whose signal has not been refreshed for two weeks (this is how keys belonging to a permanently-removed server are cleaned up). + +The two discovery modes use independent liveness signals: alias-based discovery refreshes the alias `LastUpdatedDate`, while tag-based discovery refreshes the `spire-last-update` tag. A server only refreshes the signal for its currently configured mode. This has an implication for rollbacks: + +- Migrating forward (alias-based to tag-based) across the fleet is safe: keys are tagged automatically on startup and refreshed from then on. +- Rolling back from tag-based to alias-based on part of the fleet while other servers remain tag-based is only safe within the two-week window. A rolled-back server keeps its aliases fresh but stops refreshing `spire-last-update`, so after two weeks the still tag-based servers will treat its in-use keys as abandoned and schedule them for deletion. To roll back safely, change the setting across all servers in the trust domain within that window. + +### Alias-based Key Discovery + +> **Note:** Alias-based key discovery will be deprecated in a future version and removed in a later one. Enable `enable_tag_based_key_discovery` to migrate. -Aliases managed by the plugin have the following form: `alias/SPIRE_SERVER/{TRUST_DOMAIN}/{SERVER_ID}/{KEY_ID}`. The `{SERVER_ID}` is the identifier handled by the `key_identifier_file` or `key_identifier_value` setting. This ID allows multiple servers in the same trust domain (e.g. servers in HA deployments) to manage keys with identical `{KEY_ID}`'s without collision. The `{KEY_ID}` in the alias name is encoded to use a [character set accepted by KMS](https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html#API_CreateAlias_RequestSyntax). +By default, the plugin uses alias-based key discovery. The plugin assigns [aliases](https://docs.aws.amazon.com/kms/latest/developerguide/kms-alias.html) to the Customer Master Keys that it manages. Aliases have the following form: `alias/SPIRE_SERVER/{TRUST_DOMAIN}/{SERVER_ID}/{KEY_ID}`. The `{KEY_ID}` in the alias name is encoded to use a [character set accepted by KMS](https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html#API_CreateAlias_RequestSyntax). The plugin attempts to detect and prune stale aliases. To facilitate stale alias detection, the plugin actively updates the `LastUpdatedDate` field on all aliases every 6 hours. The plugin periodically scans aliases. Any alias encountered with a `LastUpdatedDate` older than two weeks is removed, along with its associated key. @@ -71,14 +103,22 @@ The IAM role must have an attached policy with the following permissions: - `kms:CreateKey` - `kms:DescribeKey` - `kms:GetPublicKey` -- `kms:ListKeys` - `kms:ListAliases` - `kms:ScheduleKeyDeletion` - `kms:Sign` -- `kms:TagResource` (required when using key tagging) - `kms:UpdateAlias` - `kms:DeleteAlias` +The following additional permissions are required depending on the configuration: + +| Permission | Required when | +|--------------------|-------------------------------------------------------------------------| +| `kms:ListKeys` | Using alias-based key discovery (current default) | +| `kms:TagResource` | Using tag-based key discovery or `key_tags` | +| `tag:GetResources` | Using tag-based key discovery (`enable_tag_based_key_discovery = true`) | + +`tag:GetResources` belongs to the Resource Groups Tagging API, not to KMS. It is an identity-based permission and must be granted in the IAM identity's policy. It cannot be granted through the KMS key policy (including the default policy generated by the plugin). + ### Key policy The plugin can generate keys using a default key policy, or it can load and use a user defined policy. @@ -134,7 +174,19 @@ is set, the plugin uses the policy defined in the file instead of the default po KeyManager "aws_kms" { plugin_data { region = "us-east-2" - key_metadata_file = "./key_metadata" + key_identifier_file = "./key_metadata" + } +} +``` + +### Configuration with tag-based key discovery (recommended) + +```hcl +KeyManager "aws_kms" { + plugin_data { + region = "us-east-2" + key_identifier_file = "./key_metadata" + enable_tag_based_key_discovery = true } } ``` @@ -145,7 +197,8 @@ KeyManager "aws_kms" { KeyManager "aws_kms" { plugin_data { region = "us-east-2" - key_metadata_file = "./key_metadata" + key_identifier_file = "./key_metadata" + enable_tag_based_key_discovery = true key_tags = { Environment = "production" Team = "security" diff --git a/go.mod b/go.mod index e498fa4b50..d0fdc6fa79 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/Masterminds/sprig/v3 v3.3.0 github.com/Microsoft/go-winio v0.6.2 github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 - github.com/aws/aws-sdk-go-v2 v1.41.7 + github.com/aws/aws-sdk-go-v2 v1.41.9 github.com/aws/aws-sdk-go-v2/config v1.32.14 github.com/aws/aws-sdk-go-v2/credentials v1.19.14 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 @@ -32,6 +32,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/iam v1.53.1 github.com/aws/aws-sdk-go-v2/service/kms v1.52.0 github.com/aws/aws-sdk-go-v2/service/organizations v1.51.1 + github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi v1.32.2 github.com/aws/aws-sdk-go-v2/service/rolesanywhere v1.23.0 github.com/aws/aws-sdk-go-v2/service/s3 v1.102.0 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.0 @@ -126,8 +127,8 @@ require ( github.com/armon/go-radix v1.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.25 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.25 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect diff --git a/go.sum b/go.sum index 260589f24d..d6b092d1bb 100644 --- a/go.sum +++ b/go.sum @@ -117,8 +117,8 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go v1.55.8 h1:JRmEUbU52aJQZ2AjX4q4Wu7t4uZjOu71uyNmaWlUkJQ= github.com/aws/aws-sdk-go v1.55.8/go.mod h1:ZkViS9AqA6otK+JBBNH2++sx1sgxrPKcSzPPvQkUtXk= -github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8= -github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc= +github.com/aws/aws-sdk-go-v2 v1.41.9 h1:/rYeyO2+HrMztAmxAq9++XJtFMqSIpSsNA0yDGALYq4= +github.com/aws/aws-sdk-go-v2 v1.41.9/go.mod h1:+HsoOEX80qAVUitj1A2DhCNTjmb3edVyuDypb6LNEeo= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY= github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= @@ -129,10 +129,10 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeD github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.0 h1:SE3IDYzg2WwsAmkxSnEGuW/Bek8js245j1lGwZJpl1E= github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.6.0/go.mod h1:duFNXIVHPkyfllpU5GuJ+QoiETTsDWSOMvpOEcy5Kss= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.25 h1:Uii3frf9ztec/ABM2/FSH9/z7PLzxfpG8h4RpkUFflQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.25/go.mod h1:G6kntsA2GorAxDPbap6xgB2F+amSLUF8GJTi7PUoX44= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.25 h1:r1+/l6m+WaUJF9HISEsNOLHSNj5EXYQxK8VX6Cz9NlA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.25/go.mod h1:cKf+D+NMDK1LndD7BowHbBZPgR9V0/5HubH0PFWvA+c= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA= @@ -159,6 +159,8 @@ github.com/aws/aws-sdk-go-v2/service/kms v1.52.0 h1:QNtg+Mtj1zmepk568+UKBD5DFfqh github.com/aws/aws-sdk-go-v2/service/kms v1.52.0/go.mod h1:Y0+uxvxz6ib4KktRdK0V4X45Vcs/JyYoz8H71pO8xeI= github.com/aws/aws-sdk-go-v2/service/organizations v1.51.1 h1:5hM1jQjIzEiu07ZqQ8iI4sC+06C8a+idNtytO65dhAw= github.com/aws/aws-sdk-go-v2/service/organizations v1.51.1/go.mod h1:urLFj1twuR/h5T0wN/2/kmY1gxBFa1tTKr+c60lZ2fA= +github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi v1.32.2 h1:LC3ALu3cQVkh7umM+x8zE0UxVWS/gllEt5VuNchyUW8= +github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi v1.32.2/go.mod h1:gBZ5iZqcOsvR8pIZS0CsbGfoUUEyiS8qjxQXRjdsxZA= github.com/aws/aws-sdk-go-v2/service/rolesanywhere v1.23.0 h1:BD6q8KEiom1qJpDC7kmBrKwMXCRkuSy7JXWPhQThFkI= github.com/aws/aws-sdk-go-v2/service/rolesanywhere v1.23.0/go.mod h1:kJgqmGmlOPqTe9cntPrSgMdIQ3cAama8TAW+NKaDCzE= github.com/aws/aws-sdk-go-v2/service/s3 v1.102.0 h1:gfPQ6do5PZTCc5n/vZUHz/G8McrNrfERGSO+iHvVbCA= diff --git a/pkg/server/plugin/keymanager/awskms/awskms.go b/pkg/server/plugin/keymanager/awskms/awskms.go index f312595d98..7d92949e13 100644 --- a/pkg/server/plugin/keymanager/awskms/awskms.go +++ b/pkg/server/plugin/keymanager/awskms/awskms.go @@ -9,6 +9,7 @@ import ( "os" "path" "regexp" + "strconv" "strings" "sync" "time" @@ -17,6 +18,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + rgtatypes "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gofrs/uuid/v5" "github.com/hashicorp/go-hclog" @@ -34,16 +37,29 @@ const ( pluginName = "aws_kms" aliasPrefix = "alias/SPIRE_SERVER/" + // Logging tags keyArnTag = "key_arn" aliasNameTag = "alias_name" reasonTag = "reason" + // KMS resource tags for key discovery + tagKeyServerTD = "spire-server-td" // Trust domain (no hashing needed - AWS allows dots and long values) + tagKeyServerID = "spire-server-id" // Server identifier + tagKeyLastUpdate = "spire-last-update" // Unix timestamp of last update + tagKeyActive = "spire-active" // "true" if key is actively managed + tagKeySPIREKeyID = "spire-key-id" // SPIRE key identifier + + // Alias-based discovery task frequencies (legacy; will be deprecated in a future version and removed in a later one) refreshAliasesFrequency = time.Hour * 6 disposeAliasesFrequency = time.Hour * 24 aliasThreshold = time.Hour * 24 * 14 // two weeks + disposeKeysFrequency = time.Hour * 48 + keyThreshold = time.Hour * 48 // 48 hours for orphaned keys without aliases - disposeKeysFrequency = time.Hour * 48 - keyThreshold = time.Hour * 48 + // Tag-based discovery task frequencies + keepActiveKeysFrequency = time.Hour * 6 + disposeKeysViaTagsFrequency = time.Hour * 48 + keyThresholdForTagDiscovery = time.Hour * 24 * 14 // two weeks for tagged keys ) var ( @@ -69,14 +85,16 @@ type keyEntry struct { } type pluginHooks struct { - newKMSClient func(aws.Config) (kmsClient, error) - newSTSClient func(aws.Config) (stsClient, error) - clk clock.Clock + newKMSClient func(aws.Config) (kmsClient, error) + newTaggingClient func(aws.Config) (taggingClient, error) + newSTSClient func(aws.Config) (stsClient, error) + clk clock.Clock // just for testing scheduleDeleteSignal chan error refreshAliasesSignal chan error disposeAliasesSignal chan error disposeKeysSignal chan error + keepActiveKeysSignal chan error } // Plugin is the main representation of this keymanager plugin @@ -88,6 +106,7 @@ type Plugin struct { mu sync.RWMutex entries map[string]keyEntry kmsClient kmsClient + taggingClient taggingClient stsClient stsClient trustDomain string serverID string @@ -96,6 +115,9 @@ type Plugin struct { hooks pluginHooks keyPolicy *string keyTags []types.Tag + + // useTagBasedDiscovery indicates whether to use tag-based or alias-based key discovery + useTagBasedDiscovery bool } // Config provides configuration context for the plugin @@ -107,6 +129,20 @@ type Config struct { KeyIdentifierValue string `hcl:"key_identifier_value" json:"key_identifier_value"` KeyPolicyFile string `hcl:"key_policy_file" json:"key_policy_file"` KeyTags map[string]string `hcl:"key_tags" json:"key_tags"` + + // EnableTagBasedKeyDiscovery enables the use of AWS Resource Groups Tagging API + // for efficient key discovery instead of the legacy alias-based approach. + // When enabled, keys are discovered using SPIRE-specific tags (spire-server-td, + // spire-server-id, spire-active). + // This eliminates the need for broad ListKeys + DescribeKey permissions and reduces API costs. + // + // Default: false (uses legacy alias-based discovery) + // In a future SPIRE version, this will default to true. The alias-based + // approach will be deprecated in a future version and removed in a later one. + // + // Note: When enabled, the plugin requires permission to use the + // resourcegroupstaggingapi:GetResources API action. + EnableTagBasedKeyDiscovery bool `hcl:"enable_tag_based_key_discovery" json:"enable_tag_based_key_discovery"` } func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { @@ -152,19 +188,21 @@ func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginco // New returns an instantiated plugin func New() *Plugin { - return newPlugin(newKMSClient, newSTSClient) + return newPlugin(newKMSClient, newTaggingClient, newSTSClient) } func newPlugin( newKMSClient func(aws.Config) (kmsClient, error), + newTaggingClient func(aws.Config) (taggingClient, error), newSTSClient func(aws.Config) (stsClient, error), ) *Plugin { return &Plugin{ entries: make(map[string]keyEntry), hooks: pluginHooks{ - newKMSClient: newKMSClient, - newSTSClient: newSTSClient, - clk: clock.New(), + newKMSClient: newKMSClient, + newTaggingClient: newTaggingClient, + newSTSClient: newSTSClient, + clk: clock.New(), }, scheduleDelete: make(chan string, 120), } @@ -215,16 +253,57 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) return nil, status.Errorf(codes.Internal, "failed to create KMS client: %v", err) } - fetcher := &keyFetcher{ - log: p.log, - kmsClient: kc, - serverID: serverID, - trustDomain: req.CoreConfiguration.TrustDomain, - } - p.log.Debug("Fetching key aliases from KMS") - keyEntries, err := fetcher.fetchKeyEntries(ctx) - if err != nil { - return nil, err + // Determine which discovery mode to use + useTagBasedDiscovery := newConfig.EnableTagBasedKeyDiscovery + + if useTagBasedDiscovery { + p.log.Info("Tag-based key discovery enabled") + } else { + p.log.Warn("Alias-based key discovery will be deprecated in a future version and removed in a later one. " + + "Enable 'enable_tag_based_key_discovery' to switch to tag-based discovery, which efficiently " + + "finds only the keys managed by this plugin instance.") + } + + // Initialize the appropriate fetcher based on configuration + var keyEntries []*keyEntry + var spireTags []types.Tag + if useTagBasedDiscovery { + // Build SPIRE-specific tags so they can be used during migration + // and applied to newly created keys. + spireTags = p.buildSPIRETags(serverID, req.CoreConfiguration.TrustDomain) + // Create tagging client for tag-based discovery + tc, err := p.hooks.newTaggingClient(awsCfg) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create tagging client: %v", err) + } + + fetcher := &keyFetcher{ + log: p.log, + kmsClient: kc, + taggingClient: tc, + serverID: serverID, + trustDomain: req.CoreConfiguration.TrustDomain, + } + p.log.Debug("Fetching keys using tag-based discovery from AWS Resource Groups Tagging API") + lastUpdate := strconv.FormatInt(p.hooks.clk.Now().Unix(), 10) + keyEntries, err = fetcher.fetchKeyEntriesWithMigration(ctx, spireTags, lastUpdate) + if err != nil { + return nil, err + } + p.taggingClient = tc + } else { + // Use legacy alias-based discovery + fetcher := &keyFetcher{ + log: p.log, + kmsClient: kc, + serverID: serverID, + trustDomain: req.CoreConfiguration.TrustDomain, + } + p.log.Debug("Fetching keys using legacy alias-based discovery from KMS") + keyEntries, err = fetcher.fetchKeyEntriesViaAlias(ctx) + if err != nil { + return nil, err + } } p.mu.Lock() @@ -235,10 +314,23 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) p.stsClient = sc p.trustDomain = req.CoreConfiguration.TrustDomain p.serverID = serverID + p.useTagBasedDiscovery = useTagBasedDiscovery - if len(newConfig.KeyTags) > 0 { + // Build the tag list applied to every new key. SPIRE-specific tags are + // only included when tag-based discovery is enabled, so that the legacy + // alias-based path does not require the kms:TagResource permission. + switch { + case useTagBasedDiscovery && len(newConfig.KeyTags) > 0: + userTags := buildKeyTags(newConfig.KeyTags) + // Build a fresh slice to avoid mutating the spireTags backing array. + p.keyTags = make([]types.Tag, 0, len(spireTags)+len(userTags)) + p.keyTags = append(p.keyTags, spireTags...) + p.keyTags = append(p.keyTags, userTags...) + case useTagBasedDiscovery: + p.keyTags = spireTags + case len(newConfig.KeyTags) > 0: p.keyTags = buildKeyTags(newConfig.KeyTags) - } else { + default: p.keyTags = nil } @@ -247,12 +339,21 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) p.cancelTasks() } - // start tasks + // Start background tasks based on discovery mode ctx, p.cancelTasks = context.WithCancel(context.Background()) go p.scheduleDeleteTask(ctx) + + // Always refresh aliases so a downgrade to a version without + // tag-based discovery still finds keys with fresh aliases. go p.refreshAliasesTask(ctx) - go p.disposeAliasesTask(ctx) - go p.disposeKeysTask(ctx) + + if useTagBasedDiscovery { + go p.keepActiveKeysTask(ctx) + go p.disposeKeysViaTagsTask(ctx) + } else { + go p.disposeAliasesTask(ctx) + go p.disposeKeysTask(ctx) + } return &configv1.ConfigureResponse{}, nil } @@ -387,7 +488,30 @@ func (p *Plugin) createKey(ctx context.Context, spireKeyID string, keyType keyma Policy: p.keyPolicy, } - if len(p.keyTags) > 0 { + if p.useTagBasedDiscovery { + // When tag-based discovery is enabled, append the per-key SPIRE key + // ID tag so the key can be looked up by ID via the tagging API, and + // stamp spire-last-update so the key is immediately eligible for + // staleness evaluation. Stamping at creation (rather than waiting for + // the first keepActiveKeys tick) ensures a key is never left with + // spire-active=true but no spire-last-update, which would make it + // undisposable by other servers if this server dies before that tick. + // Build a fresh slice to avoid mutating the shared p.keyTags slice. + tags := make([]types.Tag, len(p.keyTags), len(p.keyTags)+2) + copy(tags, p.keyTags) + tags = append(tags, + types.Tag{ + TagKey: aws.String(tagKeySPIREKeyID), + TagValue: aws.String(spireKeyID), + }, + types.Tag{ + TagKey: aws.String(tagKeyLastUpdate), + TagValue: aws.String(strconv.FormatInt(p.hooks.clk.Now().Unix(), 10)), + }, + ) + createKeyInput.Tags = tags + } else if len(p.keyTags) > 0 { + // Legacy alias-based mode: only apply user-defined tags (if any). createKeyInput.Tags = p.keyTags } @@ -1078,6 +1202,29 @@ func buildKeyTags(tags map[string]string) []types.Tag { return keyTags } +// buildSPIRETags creates the SPIRE-specific tags that are added to all KMS keys +// at creation time. These tags enable efficient key discovery via the AWS +// Resource Groups Tagging API. +// +// Note: spire-last-update is intentionally omitted here. It is set exclusively +// by keepActiveKeys, which runs on a regular schedule. +func (p *Plugin) buildSPIRETags(serverID, trustDomain string) []types.Tag { + return []types.Tag{ + { + TagKey: aws.String(tagKeyServerTD), + TagValue: aws.String(trustDomain), + }, + { + TagKey: aws.String(tagKeyServerID), + TagValue: aws.String(serverID), + }, + { + TagKey: aws.String(tagKeyActive), + TagValue: aws.String("true"), + }, + } +} + // encodeKeyID maps "." and "+" characters to the asciihex value using "_" as // escape character. Currently, KMS does not support those characters to be used // as alias name. @@ -1094,3 +1241,218 @@ func decodeKeyID(keyID string) string { keyID = strings.ReplaceAll(keyID, "_2b", "+") return keyID } + +// keepActiveKeysTask updates the spire-last-update tag on all managed keys every 6 hours. +// This allows detection of keys that are no longer in use by any server. +// This task only runs when tag-based discovery is enabled. +func (p *Plugin) keepActiveKeysTask(ctx context.Context) { + ticker := p.hooks.clk.Ticker(keepActiveKeysFrequency) + defer ticker.Stop() + + p.notifyKeepActiveKeys(nil) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := p.keepActiveKeys(ctx) + p.notifyKeepActiveKeys(err) + } + } +} + +// keepActiveKeys updates the last-update tag on all keys managed by this server. +func (p *Plugin) keepActiveKeys(ctx context.Context) error { + p.log.Debug("Updating last-update tag on managed keys") + + // Snapshot entries under the lock so we don't hold it across network calls. + p.mu.RLock() + entries := make([]keyEntry, 0, len(p.entries)) + for _, e := range p.entries { + entries = append(entries, e) + } + p.mu.RUnlock() + + now := strconv.FormatInt(p.hooks.clk.Now().Unix(), 10) + var errs []string + + for _, entry := range entries { + _, err := p.kmsClient.TagResource(ctx, &kms.TagResourceInput{ + KeyId: &entry.Arn, + Tags: []types.Tag{ + { + TagKey: aws.String(tagKeyLastUpdate), + TagValue: aws.String(now), + }, + }, + }) + if err != nil { + p.log.Error("Failed to update last-update tag", keyArnTag, entry.Arn, reasonTag, err) + errs = append(errs, err.Error()) + } + } + + if errs != nil { + return errors.New(strings.Join(errs, ": ")) + } + return nil +} + +// disposeKeysViaTagsTask finds and disposes of stale keys using tag-based filtering. +// This runs every 48 hours and looks for keys with spire-active=true but with +// a spire-last-update timestamp older than 2 weeks that don't belong to this server. +// This is the tag-based equivalent of disposeAliasesTask. +func (p *Plugin) disposeKeysViaTagsTask(ctx context.Context) { + ticker := p.hooks.clk.Ticker(disposeKeysViaTagsFrequency) + defer ticker.Stop() + + p.notifyDisposeKeys(nil) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := p.disposeKeysViaTags(ctx) + p.notifyDisposeKeys(err) + } + } +} + +// disposeKeysViaTags uses the AWS Resource Groups Tagging API to find stale keys. +func (p *Plugin) disposeKeysViaTags(ctx context.Context) error { + p.log.Debug("Looking for stale keys to dispose using tag-based discovery") + + now := p.hooks.clk.Now() + staleThreshold := now.Add(-keyThresholdForTagDiscovery).Unix() + + // Find all keys in this trust domain that are active + tagFilters := []rgtatypes.TagFilter{ + { + Key: aws.String(tagKeyServerTD), + Values: []string{p.trustDomain}, + }, + { + Key: aws.String(tagKeyActive), + Values: []string{"true"}, + }, + } + + paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(p.taggingClient, &resourcegroupstaggingapi.GetResourcesInput{ + ResourceTypeFilters: []string{"kms:key"}, + TagFilters: tagFilters, + }) + + var errs []string + for { + resourcesResp, err := paginator.NextPage(ctx) + switch { + case err != nil: + if permErr := tagGetResourcesPermissionError(err); permErr != nil { + p.log.Error("Failed to fetch keys for disposal", reasonTag, permErr) + return permErr + } + p.log.Error("Failed to fetch keys for disposal", reasonTag, err) + return err + case resourcesResp == nil: + p.log.Error("Failed to fetch keys for disposal: nil response") + return errors.New("nil response from tagging API") + } + + for _, resource := range resourcesResp.ResourceTagMappingList { + if resource.ResourceARN == nil { + continue + } + + keyArn := *resource.ResourceARN + + // Check if this key belongs to the current server + var belongsToThisServer bool + var lastUpdateTimestamp int64 + var hasLastUpdate bool + var malformedTimestamp bool + for _, tag := range resource.Tags { + if tag.Key != nil && *tag.Key == tagKeyServerID && tag.Value != nil && *tag.Value == p.serverID { + belongsToThisServer = true + } + if tag.Key != nil && *tag.Key == tagKeyLastUpdate && tag.Value != nil { + ts, err := strconv.ParseInt(*tag.Value, 10, 64) + if err != nil { + malformedTimestamp = true + continue + } + lastUpdateTimestamp = ts + hasLastUpdate = true + } + } + + if malformedTimestamp && !hasLastUpdate { + p.log.Warn("Malformed spire-last-update tag value, skipping key", + keyArnTag, keyArn) + continue + } + + // Skip keys belonging to this server + if belongsToThisServer { + continue + } + + // Skip keys that have been updated recently. A key with no + // spire-last-update tag has lastUpdateTimestamp == 0 (the Unix + // epoch), the oldest possible value, so it is treated as stale: + // every key managed by the plugin is stamped at creation and + // migration, and a missing value indicates an abandoned key. + if lastUpdateTimestamp > staleThreshold { + continue + } + + log := p.log.With(keyArnTag, keyArn) + log.Debug("Found stale key beyond threshold") + + // Schedule the key for deletion. Only mark it inactive once it has + // been enqueued: marking spire-active=false drops the key from the + // GetResources(active=true) filter, so if the queue is full and we + // skipped enqueueing, leaving it active=true lets the next cycle + // retry it. Tag mode has no creation-date orphan sweeper to fall + // back on, unlike the legacy alias path. + select { + case p.scheduleDelete <- keyArn: + log.Debug("Key enqueued for deletion") + default: + log.Error("Failed to enqueue key for deletion; leaving key active for retry on the next cycle") + continue + } + + // Mark the key as inactive by updating the spire-active tag. + _, err := p.kmsClient.TagResource(ctx, &kms.TagResourceInput{ + KeyId: &keyArn, + Tags: []types.Tag{ + { + TagKey: aws.String(tagKeyActive), + TagValue: aws.String("false"), + }, + }, + }) + if err != nil { + log.Error("Failed to mark key as inactive", reasonTag, err) + errs = append(errs, err.Error()) + } + } + + if !paginator.HasMorePages() { + break + } + } + + if errs != nil { + return errors.New(strings.Join(errs, ": ")) + } + return nil +} + +func (p *Plugin) notifyKeepActiveKeys(err error) { + if p.hooks.keepActiveKeysSignal != nil { + p.hooks.keepActiveKeysSignal <- err + } +} diff --git a/pkg/server/plugin/keymanager/awskms/awskms_test.go b/pkg/server/plugin/keymanager/awskms/awskms_test.go index 217f319285..3981ccda77 100644 --- a/pkg/server/plugin/keymanager/awskms/awskms_test.go +++ b/pkg/server/plugin/keymanager/awskms/awskms_test.go @@ -12,6 +12,7 @@ import ( "path" "path/filepath" "runtime" + "strconv" "strings" "testing" "time" @@ -19,6 +20,8 @@ import ( "github.com/andres-erbsen/clock" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kms/types" + rgtatypes "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" + "github.com/aws/smithy-go" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -90,8 +93,10 @@ func TestKeyManagerContract(t *testing.T) { c := clock.NewMock() fakeKMSClient := newKMSClientFake(t, c) fakeSTSClient := newSTSClientFake() + fakeTaggingClient := newTaggingClientFake() p := newPlugin( func(aws.Config) (kmsClient, error) { return fakeKMSClient, nil }, + func(aws.Config) (taggingClient, error) { return fakeTaggingClient, nil }, func(aws.Config) (stsClient, error) { return fakeSTSClient, nil }, ) km := new(keymanager.V1) @@ -121,11 +126,12 @@ func TestKeyManagerContract(t *testing.T) { } type pluginTest struct { - plugin *Plugin - fakeKMSClient *kmsClientFake - fakeSTSClient *stsClientFake - logHook *test.Hook - clockHook *clock.Mock + plugin *Plugin + fakeKMSClient *kmsClientFake + fakeSTSClient *stsClientFake + fakeTaggingClient *taggingClientFake + logHook *test.Hook + clockHook *clock.Mock } func setupTest(t *testing.T) *pluginTest { @@ -135,8 +141,10 @@ func setupTest(t *testing.T) *pluginTest { c := clock.NewMock() fakeKMSClient := newKMSClientFake(t, c) fakeSTSClient := newSTSClientFake() + fakeTaggingClient := newTaggingClientFake() p := newPlugin( func(aws.Config) (kmsClient, error) { return fakeKMSClient, nil }, + func(aws.Config) (taggingClient, error) { return fakeTaggingClient, nil }, func(aws.Config) (stsClient, error) { return fakeSTSClient, nil }, ) km := new(keymanager.V1) @@ -145,11 +153,12 @@ func setupTest(t *testing.T) *pluginTest { p.hooks.clk = c return &pluginTest{ - plugin: p, - fakeKMSClient: fakeKMSClient, - fakeSTSClient: fakeSTSClient, - logHook: logHook, - clockHook: c, + plugin: p, + fakeKMSClient: fakeKMSClient, + fakeSTSClient: fakeSTSClient, + fakeTaggingClient: fakeTaggingClient, + logHook: logHook, + clockHook: c, } } @@ -2563,3 +2572,606 @@ func waitForSignal(t *testing.T, ch chan error) error { } return nil } + +// configureTagBasedRequest returns a ConfigureRequest with tag-based key +// discovery enabled, using the validServerID file for the server identifier. +func configureTagBasedRequest(t *testing.T) *configv1.ConfigureRequest { + return &configv1.ConfigureRequest{ + CoreConfiguration: &configv1.CoreConfiguration{TrustDomain: "test.example.org"}, + HclConfiguration: fmt.Sprintf(`{ + "access_key_id": %q, + "secret_access_key": %q, + "region": %q, + "key_identifier_file": %q, + "enable_tag_based_key_discovery": true + }`, validAccessKeyID, validSecretAccessKey, validRegion, getKeyIdentifierFile(t)), + } +} + +// makeTaggedResource builds a ResourceTagMapping representing a KMS key that +// is actively managed by the given server, with spire-key-id set. +func makeTaggedResource(keyArn, spireKeyID, serverID, trustDomain string) rgtatypes.ResourceTagMapping { + return rgtatypes.ResourceTagMapping{ + ResourceARN: aws.String(keyArn), + Tags: []rgtatypes.Tag{ + {Key: aws.String(tagKeyServerTD), Value: aws.String(trustDomain)}, + {Key: aws.String(tagKeyServerID), Value: aws.String(serverID)}, + {Key: aws.String(tagKeyActive), Value: aws.String("true")}, + {Key: aws.String(tagKeySPIREKeyID), Value: aws.String(spireKeyID)}, + }, + } +} + +func TestConfigureWithTagBasedDiscovery(t *testing.T) { + const ( + tagKeyID = "tag-key-01" + tagKeyArn = fakeKeyArnPrefix + tagKeyID + tagSpireKey = "x509-CA-A" + ) + + for _, tt := range []struct { + name string + err string + code codes.Code + fakeEntries []fakeKeyEntry + taggedResources []rgtatypes.ResourceTagMapping + getResourcesErr error + tagResourceErr string + expectEntryCount int + expectTagResourceCalls int + }{ + { + name: "load keys via tag-based discovery", + fakeEntries: []fakeKeyEntry{ + { + KeyID: aws.String(tagKeyID), + KeySpec: types.KeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("fake-public-key"), + }, + }, + taggedResources: []rgtatypes.ResourceTagMapping{ + makeTaggedResource(tagKeyArn, tagSpireKey, validServerID, "test.example.org"), + }, + expectEntryCount: 1, + expectTagResourceCalls: 0, + }, + { + name: "migrate legacy keys to tag-based discovery", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.KeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("fake-public-key"), + }, + }, + // tagging fake returns empty — key has no SPIRE tags yet + expectEntryCount: 1, + expectTagResourceCalls: 1, + }, + { + name: "deduplicate key found via both tag and alias", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.KeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("fake-public-key"), + }, + }, + // tagging fake already has the key — no migration needed + taggedResources: []rgtatypes.ResourceTagMapping{ + makeTaggedResource(KeyArn, spireKeyID, validServerID, "test.example.org"), + }, + expectEntryCount: 1, + expectTagResourceCalls: 0, + }, + { + name: "migration: tag resource error is non-fatal", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.KeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("fake-public-key"), + }, + }, + tagResourceErr: "tag resource failed", + expectEntryCount: 1, + }, + { + // A key found via tags but disabled or pending deletion (e.g. + // scheduled for deletion by the legacy alias path, which does not + // clear SPIRE tags) is skipped rather than failing startup. + name: "skips disabled key found via tags", + fakeEntries: []fakeKeyEntry{ + { + KeyID: aws.String(tagKeyID), + KeySpec: types.KeySpecEccNistP256, + Enabled: false, + PublicKey: []byte("fake-public-key"), + }, + }, + taggedResources: []rgtatypes.ResourceTagMapping{ + makeTaggedResource(tagKeyArn, tagSpireKey, validServerID, "test.example.org"), + }, + expectEntryCount: 0, + expectTagResourceCalls: 0, + }, + { + name: "GetResources error fails configure", + getResourcesErr: errors.New("tagging API unavailable"), + err: "failed to fetch keys by tags: tagging API unavailable", + code: codes.Internal, + }, + { + name: "GetResources access denied returns actionable error", + getResourcesErr: &smithy.GenericAPIError{ + Code: "AccessDeniedException", + Message: "User is not authorized to perform: tag:GetResources", + }, + err: `tag-based key discovery requires the "tag:GetResources" permission`, + code: codes.FailedPrecondition, + }, + } { + t.Run(tt.name, func(t *testing.T) { + ts := setupTest(t) + ts.fakeKMSClient.setEntries(tt.fakeEntries) + ts.fakeTaggingClient.setResources(tt.taggedResources) + ts.fakeTaggingClient.setErr(tt.getResourcesErr) + ts.fakeKMSClient.setTagResourceErr(tt.tagResourceErr) + + _, err := ts.plugin.Configure(ctx, configureTagBasedRequest(t)) + + if tt.err != "" { + spiretest.RequireGRPCStatusContains(t, err, tt.code, tt.err) + return + } + require.NoError(t, err) + + ts.plugin.mu.RLock() + entryCount := len(ts.plugin.entries) + ts.plugin.mu.RUnlock() + require.Equal(t, tt.expectEntryCount, entryCount) + + ts.fakeKMSClient.mu.RLock() + tagCalls := len(ts.fakeKMSClient.tagResourceCalls) + ts.fakeKMSClient.mu.RUnlock() + require.Equal(t, tt.expectTagResourceCalls, tagCalls) + }) + } +} + +// TestGenerateKeyTagBased verifies that keys created while tag-based discovery +// is enabled are stamped at creation with the SPIRE discovery tags, including +// spire-last-update, so they are immediately eligible for staleness evaluation +// and never left undisposable if the server dies before keepActiveKeys runs. +func TestGenerateKeyTagBased(t *testing.T) { + ts := setupTest(t) + + _, err := ts.plugin.Configure(ctx, configureTagBasedRequest(t)) + require.NoError(t, err) + + _, err = ts.plugin.GenerateKey(ctx, &keymanagerv1.GenerateKeyRequest{ + KeyId: spireKeyID, + KeyType: keymanagerv1.KeyType_EC_P256, + }) + require.NoError(t, err) + + ts.fakeKMSClient.mu.RLock() + createCalls := ts.fakeKMSClient.createKeyCalls + ts.fakeKMSClient.mu.RUnlock() + require.Len(t, createCalls, 1) + + tags := make(map[string]string, len(createCalls[0].Tags)) + for _, tag := range createCalls[0].Tags { + require.NotNil(t, tag.TagKey) + require.NotNil(t, tag.TagValue) + tags[*tag.TagKey] = *tag.TagValue + } + + require.Equal(t, "test.example.org", tags[tagKeyServerTD]) + require.Equal(t, validServerID, tags[tagKeyServerID]) + require.Equal(t, "true", tags[tagKeyActive]) + require.Equal(t, spireKeyID, tags[tagKeySPIREKeyID]) + require.Equal(t, strconv.FormatInt(ts.clockHook.Now().Unix(), 10), tags[tagKeyLastUpdate]) +} + +// TestFetchKeyEntryDetailsFromArn covers the defensive error branches of +// tag-based key detail retrieval that are not exercised by the higher-level +// Configure tests. +func TestFetchKeyEntryDetailsFromArn(t *testing.T) { + const ( + fdKeyID = "fd-key-01" + fdKeyArn = fakeKeyArnPrefix + fdKeyID + ) + + for _, tt := range []struct { + name string + fakeEntries []fakeKeyEntry + describeKeyErr string + describeMalformed bool + getPublicKeyErr string + expectErr string + expectNilEntry bool + }{ + { + name: "describe key error", + describeKeyErr: "describe boom", + expectErr: "failed to describe key: describe boom", + }, + { + name: "malformed describe response", + describeMalformed: true, + expectErr: "malformed describe key response", + }, + { + name: "disabled key is skipped", + fakeEntries: []fakeKeyEntry{ + {KeyID: aws.String(fdKeyID), KeySpec: types.KeySpecEccNistP256, Enabled: false, PublicKey: []byte("fake-public-key")}, + }, + expectNilEntry: true, + }, + { + name: "unsupported key spec", + fakeEntries: []fakeKeyEntry{ + {KeyID: aws.String(fdKeyID), KeySpec: types.KeySpec("BOGUS_SPEC"), Enabled: true, PublicKey: []byte("fake-public-key")}, + }, + expectErr: "unsupported key spec", + }, + { + name: "get public key error", + fakeEntries: []fakeKeyEntry{ + {KeyID: aws.String(fdKeyID), KeySpec: types.KeySpecEccNistP256, Enabled: true, PublicKey: []byte("fake-public-key")}, + }, + getPublicKeyErr: "getpub boom", + expectErr: "failed to get public key: getpub boom", + }, + { + name: "malformed get public key response", + fakeEntries: []fakeKeyEntry{ + {KeyID: aws.String(fdKeyID), KeySpec: types.KeySpecEccNistP256, Enabled: true, PublicKey: nil}, + }, + expectErr: "malformed get public key response", + }, + } { + t.Run(tt.name, func(t *testing.T) { + ts := setupTest(t) + ts.fakeKMSClient.setEntries(tt.fakeEntries) + ts.fakeKMSClient.setDescribeKeyErr(tt.describeKeyErr) + ts.fakeKMSClient.setDescribeKeyMalformed(tt.describeMalformed) + ts.fakeKMSClient.setgetPublicKeyErr(tt.getPublicKeyErr) + + kf := &keyFetcher{ + log: ts.plugin.log, + kmsClient: ts.fakeKMSClient, + serverID: validServerID, + trustDomain: "test.example.org", + } + + entry, err := kf.fetchKeyEntryDetailsFromArn(ctx, fdKeyArn, spireKeyID) + if tt.expectErr != "" { + require.ErrorContains(t, err, tt.expectErr) + require.Nil(t, entry) + return + } + require.NoError(t, err) + if tt.expectNilEntry { + require.Nil(t, entry) + return + } + require.NotNil(t, entry) + }) + } +} + +func TestKeepActiveKeys(t *testing.T) { + for _, tt := range []struct { + name string + err string + fakeEntries []fakeKeyEntry + tagResourceErr string + }{ + { + name: "updates spire-last-update tag on all managed entries", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.KeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("fake-public-key"), + }, + { + AliasName: aws.String(aliasName + "01"), + KeyID: aws.String(keyID + "01"), + KeySpec: types.KeySpecEccNistP384, + Enabled: true, + PublicKey: []byte("fake-public-key"), + }, + }, + }, + { + name: "tag update errors are returned", + fakeEntries: []fakeKeyEntry{ + { + AliasName: aws.String(aliasName), + KeyID: aws.String(keyID), + KeySpec: types.KeySpecEccNistP256, + Enabled: true, + PublicKey: []byte("fake-public-key"), + }, + }, + tagResourceErr: "tag update failed", + err: "tag update failed", + }, + } { + t.Run(tt.name, func(t *testing.T) { + ts := setupTest(t) + ts.fakeKMSClient.setEntries(tt.fakeEntries) + + keepActiveKeysSignal := make(chan error) + ts.plugin.hooks.keepActiveKeysSignal = keepActiveKeysSignal + + _, err := ts.plugin.Configure(ctx, configureTagBasedRequest(t)) + require.NoError(t, err) + + // Wait for initial signal (no-op before the first tick) + _ = waitForSignal(t, keepActiveKeysSignal) + + // Reset calls accumulated during Configure's migration path. + ts.fakeKMSClient.mu.Lock() + ts.fakeKMSClient.tagResourceCalls = nil + ts.fakeKMSClient.mu.Unlock() + + ts.fakeKMSClient.setTagResourceErr(tt.tagResourceErr) + + ts.clockHook.Add(keepActiveKeysFrequency) + err = waitForSignal(t, keepActiveKeysSignal) + + if tt.err != "" { + require.EqualError(t, err, tt.err) + return + } + require.NoError(t, err) + + // One TagResource call per entry with spire-last-update + ts.fakeKMSClient.mu.RLock() + calls := ts.fakeKMSClient.tagResourceCalls + ts.fakeKMSClient.mu.RUnlock() + require.Len(t, calls, len(tt.fakeEntries)) + + expectedTime := strconv.FormatInt(ts.clockHook.Now().Unix(), 10) + for _, call := range calls { + require.Len(t, call.Tags, 1) + require.Equal(t, tagKeyLastUpdate, *call.Tags[0].TagKey) + require.Equal(t, expectedTime, *call.Tags[0].TagValue) + } + }) + } +} + +func TestDisposeKeysViaTags(t *testing.T) { + const ( + otherServerID = "other-server-id" + otherKeyArn = fakeKeyArnPrefix + "other-server-key" + staleTimestamp = int64(0) // Unix epoch — always older than the 2-week threshold + ) + + for _, tt := range []struct { + name string + err string + taggedResources []rgtatypes.ResourceTagMapping + getResourcesErr error + useRecentTimestamp bool + expectDeleteCount int + expectInactiveTagCount int + }{ + { + name: "disposes stale key from another server", + taggedResources: []rgtatypes.ResourceTagMapping{ + { + ResourceARN: aws.String(otherKeyArn), + Tags: []rgtatypes.Tag{ + {Key: aws.String(tagKeyServerTD), Value: aws.String("test.example.org")}, + {Key: aws.String(tagKeyServerID), Value: aws.String(otherServerID)}, + {Key: aws.String(tagKeyActive), Value: aws.String("true")}, + {Key: aws.String(tagKeyLastUpdate), Value: aws.String(strconv.FormatInt(staleTimestamp, 10))}, + }, + }, + }, + expectDeleteCount: 1, + expectInactiveTagCount: 1, + }, + { + name: "skips key belonging to this server", + taggedResources: []rgtatypes.ResourceTagMapping{ + { + ResourceARN: aws.String(otherKeyArn), + Tags: []rgtatypes.Tag{ + {Key: aws.String(tagKeyServerTD), Value: aws.String("test.example.org")}, + {Key: aws.String(tagKeyServerID), Value: aws.String(validServerID)}, + {Key: aws.String(tagKeyActive), Value: aws.String("true")}, + {Key: aws.String(tagKeyLastUpdate), Value: aws.String(strconv.FormatInt(staleTimestamp, 10))}, + }, + }, + }, + expectDeleteCount: 0, + expectInactiveTagCount: 0, + }, + { + // Every key managed by the plugin is stamped with spire-last-update + // at creation and migration, so an active key from another server + // that lacks the tag is treated as abandoned and disposed. + name: "disposes key with no spire-last-update tag", + taggedResources: []rgtatypes.ResourceTagMapping{ + { + ResourceARN: aws.String(otherKeyArn), + Tags: []rgtatypes.Tag{ + {Key: aws.String(tagKeyServerTD), Value: aws.String("test.example.org")}, + {Key: aws.String(tagKeyServerID), Value: aws.String(otherServerID)}, + {Key: aws.String(tagKeyActive), Value: aws.String("true")}, + // no spire-last-update tag + }, + }, + }, + expectDeleteCount: 1, + expectInactiveTagCount: 1, + }, + { + name: "skips key with malformed spire-last-update tag", + taggedResources: []rgtatypes.ResourceTagMapping{ + { + ResourceARN: aws.String(otherKeyArn), + Tags: []rgtatypes.Tag{ + {Key: aws.String(tagKeyServerTD), Value: aws.String("test.example.org")}, + {Key: aws.String(tagKeyServerID), Value: aws.String(otherServerID)}, + {Key: aws.String(tagKeyActive), Value: aws.String("true")}, + {Key: aws.String(tagKeyLastUpdate), Value: aws.String("not-a-number")}, + }, + }, + }, + expectDeleteCount: 0, + expectInactiveTagCount: 0, + }, + { + name: "skips recently updated key", + useRecentTimestamp: true, + expectDeleteCount: 0, + expectInactiveTagCount: 0, + }, + { + name: "GetResources error is returned", + getResourcesErr: errors.New("tagging API error"), + err: "tagging API error", + }, + } { + t.Run(tt.name, func(t *testing.T) { + ts := setupTest(t) + + taggedResources := tt.taggedResources + if tt.useRecentTimestamp { + // After advancing by keyThresholdForTagDiscovery, staleThreshold = epoch. + // Use epoch+1 so the key is strictly newer than the threshold. + recentTimestamp := ts.clockHook.Now().Unix() + 1 + taggedResources = []rgtatypes.ResourceTagMapping{ + { + ResourceARN: aws.String(otherKeyArn), + Tags: []rgtatypes.Tag{ + {Key: aws.String(tagKeyServerTD), Value: aws.String("test.example.org")}, + {Key: aws.String(tagKeyServerID), Value: aws.String(otherServerID)}, + {Key: aws.String(tagKeyActive), Value: aws.String("true")}, + {Key: aws.String(tagKeyLastUpdate), Value: aws.String(strconv.FormatInt(recentTimestamp, 10))}, + }, + }, + } + } + + ts.fakeTaggingClient.setResources(taggedResources) + + // Block dispose-aliases task so it does not interfere + ts.plugin.hooks.disposeAliasesSignal = make(chan error) + disposeKeysSignal := make(chan error) + ts.plugin.hooks.disposeKeysSignal = disposeKeysSignal + deleteSignal := make(chan error) + ts.plugin.hooks.scheduleDeleteSignal = deleteSignal + + _, err := ts.plugin.Configure(ctx, configureTagBasedRequest(t)) + require.NoError(t, err) + + // Consume initial (no-op) signal before Configure returns. + _ = waitForSignal(t, disposeKeysSignal) + + // Set the error after Configure so Configure itself succeeds. + ts.fakeTaggingClient.setErr(tt.getResourcesErr) + + // Advance far enough that the stale threshold (now - 2 weeks) reaches + // epoch, making keys with lastUpdate=0 eligible for disposal. + ts.clockHook.Add(keyThresholdForTagDiscovery) + // Consume the first tick (early clock time, no keys stale yet). + _ = waitForSignal(t, disposeKeysSignal) + // Second tick runs with clk.Now() at the full advance. + err = waitForSignal(t, disposeKeysSignal) + + if tt.err != "" { + require.EqualError(t, err, tt.err) + return + } + require.NoError(t, err) + + // Wait for expected deletions to be processed + for range tt.expectDeleteCount { + _ = waitForSignal(t, deleteSignal) + } + + // Verify spire-active=false was set for each disposed key + ts.fakeKMSClient.mu.RLock() + calls := ts.fakeKMSClient.tagResourceCalls + ts.fakeKMSClient.mu.RUnlock() + + var inactiveTagCount int + for _, call := range calls { + for _, tag := range call.Tags { + if tag.TagKey != nil && *tag.TagKey == tagKeyActive && + tag.TagValue != nil && *tag.TagValue == "false" { + inactiveTagCount++ + } + } + } + require.Equal(t, tt.expectInactiveTagCount, inactiveTagCount) + }) + } +} + +// TestDisposeKeysViaTagsFullDeleteQueue verifies that when the scheduleDelete +// queue is full, a stale key is left spire-active=true (not marked inactive) +// so it is retried on the next cycle. Marking it inactive would drop it from +// the GetResources(active=true) filter without ever enqueueing it for +// deletion, and tag mode has no creation-date sweeper to reclaim it. +func TestDisposeKeysViaTagsFullDeleteQueue(t *testing.T) { + ts := setupTest(t) + ts.plugin.kmsClient = ts.fakeKMSClient + ts.plugin.taggingClient = ts.fakeTaggingClient + ts.plugin.serverID = validServerID + ts.plugin.trustDomain = "test.example.org" + + // Fill the delete queue to capacity so the enqueue hits the default branch. + ts.plugin.scheduleDelete = make(chan string, 1) + ts.plugin.scheduleDelete <- fakeKeyArnPrefix + "prefill" + + staleKeyArn := fakeKeyArnPrefix + "stale-key" + ts.fakeTaggingClient.setResources([]rgtatypes.ResourceTagMapping{ + { + ResourceARN: aws.String(staleKeyArn), + Tags: []rgtatypes.Tag{ + {Key: aws.String(tagKeyServerTD), Value: aws.String("test.example.org")}, + {Key: aws.String(tagKeyServerID), Value: aws.String("other-server-id")}, + {Key: aws.String(tagKeyActive), Value: aws.String("true")}, + {Key: aws.String(tagKeyLastUpdate), Value: aws.String("0")}, + }, + }, + }) + + // Advance so staleThreshold (now - 2 weeks) reaches the Unix epoch, making + // the lastUpdate=0 key eligible for disposal. + ts.clockHook.Add(keyThresholdForTagDiscovery) + + err := ts.plugin.disposeKeysViaTags(ctx) + require.NoError(t, err) + + // The key must not be marked inactive, since it was never enqueued. + ts.fakeKMSClient.mu.RLock() + calls := ts.fakeKMSClient.tagResourceCalls + ts.fakeKMSClient.mu.RUnlock() + for _, call := range calls { + for _, tag := range call.Tags { + if tag.TagKey != nil && *tag.TagKey == tagKeyActive { + require.Fail(t, "stale key should not be marked inactive when the delete queue is full") + } + } + } +} diff --git a/pkg/server/plugin/keymanager/awskms/client.go b/pkg/server/plugin/keymanager/awskms/client.go index cad325bc93..eadff1c88f 100644 --- a/pkg/server/plugin/keymanager/awskms/client.go +++ b/pkg/server/plugin/keymanager/awskms/client.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" "github.com/aws/aws-sdk-go-v2/service/sts" ) @@ -21,6 +22,11 @@ type kmsClient interface { Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error) ListKeys(context.Context, *kms.ListKeysInput, ...func(*kms.Options)) (*kms.ListKeysOutput, error) DeleteAlias(context.Context, *kms.DeleteAliasInput, ...func(*kms.Options)) (*kms.DeleteAliasOutput, error) + TagResource(context.Context, *kms.TagResourceInput, ...func(*kms.Options)) (*kms.TagResourceOutput, error) +} + +type taggingClient interface { + GetResources(context.Context, *resourcegroupstaggingapi.GetResourcesInput, ...func(*resourcegroupstaggingapi.Options)) (*resourcegroupstaggingapi.GetResourcesOutput, error) } type stsClient interface { @@ -31,6 +37,10 @@ func newKMSClient(c aws.Config) (kmsClient, error) { return kms.NewFromConfig(c), nil } +func newTaggingClient(c aws.Config) (taggingClient, error) { + return resourcegroupstaggingapi.NewFromConfig(c), nil +} + func newSTSClient(c aws.Config) (stsClient, error) { return sts.NewFromConfig(c), nil } diff --git a/pkg/server/plugin/keymanager/awskms/client_fake.go b/pkg/server/plugin/keymanager/awskms/client_fake.go index 500a616518..6ae9fb02b1 100644 --- a/pkg/server/plugin/keymanager/awskms/client_fake.go +++ b/pkg/server/plugin/keymanager/awskms/client_fake.go @@ -20,6 +20,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + rgtatypes "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/spiffe/spire/test/testkey" "github.com/stretchr/testify/require" @@ -44,6 +46,9 @@ type kmsClientFake struct { signErr error listKeysErr error deleteAliasErr error + tagResourceErr error + tagResourceCalls []kms.TagResourceInput + createKeyCalls []kms.CreateKeyInput expectedKeyPolicy *string } @@ -54,6 +59,12 @@ type stsClientFake struct { err string } +type taggingClientFake struct { + mu sync.RWMutex + resources []rgtatypes.ResourceTagMapping + err error +} + func newKMSClientFake(t *testing.T, c *clock.Mock) *kmsClientFake { return &kmsClientFake{ t: t, @@ -69,6 +80,33 @@ func newSTSClientFake() *stsClientFake { return &stsClientFake{} } +func newTaggingClientFake() *taggingClientFake { + return &taggingClientFake{} +} + +func (tc *taggingClientFake) GetResources(_ context.Context, _ *resourcegroupstaggingapi.GetResourcesInput, _ ...func(*resourcegroupstaggingapi.Options)) (*resourcegroupstaggingapi.GetResourcesOutput, error) { + tc.mu.RLock() + defer tc.mu.RUnlock() + if tc.err != nil { + return nil, tc.err + } + return &resourcegroupstaggingapi.GetResourcesOutput{ + ResourceTagMappingList: tc.resources, + }, nil +} + +func (tc *taggingClientFake) setResources(resources []rgtatypes.ResourceTagMapping) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.resources = resources +} + +func (tc *taggingClientFake) setErr(err error) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.err = err +} + func (s *stsClientFake) GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { if s.err != "" { return nil, errors.New(s.err) @@ -97,12 +135,14 @@ func (k *kmsClientFake) setExpectedKeyPolicy(keyPolicy *string) { } func (k *kmsClientFake) CreateKey(_ context.Context, input *kms.CreateKeyInput, _ ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { - k.mu.RLock() - defer k.mu.RUnlock() + k.mu.Lock() + defer k.mu.Unlock() if k.createKeyErr != nil { return nil, k.createKeyErr } + k.createKeyCalls = append(k.createKeyCalls, *input) + switch k.expectedKeyPolicy { case nil: require.Nil(k.t, input.Policy) @@ -473,6 +513,25 @@ func (k *kmsClientFake) setDeleteAliasErr(fakeError string) { } } +func (k *kmsClientFake) TagResource(_ context.Context, input *kms.TagResourceInput, _ ...func(*kms.Options)) (*kms.TagResourceOutput, error) { + k.mu.Lock() + defer k.mu.Unlock() + if k.tagResourceErr != nil { + return nil, k.tagResourceErr + } + call := *input + k.tagResourceCalls = append(k.tagResourceCalls, call) + return &kms.TagResourceOutput{}, nil +} + +func (k *kmsClientFake) setTagResourceErr(fakeError string) { + k.mu.Lock() + defer k.mu.Unlock() + if fakeError != "" { + k.tagResourceErr = errors.New(fakeError) + } +} + const ( fakeKeyArnPrefix = "arn:aws:kms:region:1234:key/" fakeAliasArnPrefix = "arn:aws:kms:region:1234:" diff --git a/pkg/server/plugin/keymanager/awskms/fetcher.go b/pkg/server/plugin/keymanager/awskms/fetcher.go index 504dc5f5b0..43fd476fc9 100644 --- a/pkg/server/plugin/keymanager/awskms/fetcher.go +++ b/pkg/server/plugin/keymanager/awskms/fetcher.go @@ -2,6 +2,7 @@ package awskms import ( "context" + "errors" "path" "strings" "sync" @@ -9,6 +10,9 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" + rgtatypes "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi/types" + "github.com/aws/smithy-go" "github.com/hashicorp/go-hclog" keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/keymanager/v1" "golang.org/x/sync/errgroup" @@ -17,13 +21,16 @@ import ( ) type keyFetcher struct { - log hclog.Logger - kmsClient kmsClient - serverID string - trustDomain string + log hclog.Logger + kmsClient kmsClient + taggingClient taggingClient + serverID string + trustDomain string } -func (kf *keyFetcher) fetchKeyEntries(ctx context.Context) ([]*keyEntry, error) { +// fetchKeyEntriesViaAlias uses the legacy alias-based discovery method. +// This approach lists all aliases and filters by the SPIRE prefix pattern. +func (kf *keyFetcher) fetchKeyEntriesViaAlias(ctx context.Context) ([]*keyEntry, error) { var keyEntries []*keyEntry var keyEntriesMutex sync.Mutex paginator := kms.NewListAliasesPaginator(kf.kmsClient, &kms.ListAliasesInput{Limit: aws.Int32(100)}) @@ -142,3 +149,238 @@ func (kf *keyFetcher) spireKeyIDFromAlias(aliasName string) (string, bool) { } return decodeKeyID(trimmed), true } + +// fetchKeyEntriesViaTag uses AWS Resource Groups Tagging API for efficient key discovery. +// This approach filters keys by SPIRE-specific tags similar to how GCP KMS uses labels. +// Keys are filtered by: trust domain, server ID, and active status. +func (kf *keyFetcher) fetchKeyEntriesViaTag(ctx context.Context) ([]*keyEntry, error) { + var keyEntries []*keyEntry + var keyEntriesMutex sync.Mutex + g, ctx := errgroup.WithContext(ctx) + + // Build tag filters to find only keys belonging to this server + // Unlike GCP, AWS supports dots in tag values, so we use the trust domain directly + tagFilters := []rgtatypes.TagFilter{ + { + Key: aws.String(tagKeyServerTD), + Values: []string{kf.trustDomain}, + }, + { + Key: aws.String(tagKeyServerID), + Values: []string{kf.serverID}, + }, + { + Key: aws.String(tagKeyActive), + Values: []string{"true"}, + }, + } + + // Use pagination to handle large numbers of keys + paginator := resourcegroupstaggingapi.NewGetResourcesPaginator(kf.taggingClient, &resourcegroupstaggingapi.GetResourcesInput{ + ResourceTypeFilters: []string{"kms:key"}, + TagFilters: tagFilters, + }) + + for { + resourcesResp, err := paginator.NextPage(ctx) + switch { + case err != nil: + if permErr := tagGetResourcesPermissionError(err); permErr != nil { + return nil, permErr + } + return nil, status.Errorf(codes.Internal, "failed to fetch keys by tags: %v", err) + case resourcesResp == nil: + return nil, status.Error(codes.Internal, "failed to fetch keys by tags: nil response") + } + + kf.log.Debug("Found keys with SPIRE tags", "num_keys", len(resourcesResp.ResourceTagMappingList)) + + for _, resource := range resourcesResp.ResourceTagMappingList { + if resource.ResourceARN == nil { + continue + } + + keyArn := *resource.ResourceARN + + // Extract SPIRE key ID from tags + spireKeyID, ok := kf.spireKeyIDFromTags(resource.Tags) + if !ok { + kf.log.Warn("Could not get SPIRE key ID from tags", "key_arn", keyArn) + continue + } + + // Trigger a goroutine to get the details of the key + g.Go(func() error { + entry, err := kf.fetchKeyEntryDetailsFromArn(ctx, keyArn, spireKeyID) + if err != nil { + return err + } + if entry == nil { + return nil + } + + keyEntriesMutex.Lock() + keyEntries = append(keyEntries, entry) + keyEntriesMutex.Unlock() + return nil + }) + } + + if !paginator.HasMorePages() { + break + } + } + + // Wait for all the detail gathering routines to finish + if err := g.Wait(); err != nil { + statusErr := status.Convert(err) + return nil, status.Errorf(statusErr.Code(), "failed to fetch key entries: %v", statusErr.Message()) + } + + return keyEntries, nil +} + +// fetchKeyEntryDetailsFromArn retrieves key details using a key ARN directly. +// This is used for tag-based discovery where we get the ARN from the tagging API. +func (kf *keyFetcher) fetchKeyEntryDetailsFromArn(ctx context.Context, keyArn string, spireKeyID string) (*keyEntry, error) { + describeResp, err := kf.kmsClient.DescribeKey(ctx, &kms.DescribeKeyInput{KeyId: &keyArn}) + switch { + case err != nil: + return nil, status.Errorf(codes.Internal, "failed to describe key: %v", err) + case describeResp == nil || describeResp.KeyMetadata == nil: + return nil, status.Error(codes.Internal, "malformed describe key response") + case describeResp.KeyMetadata.Arn == nil: + return nil, status.Errorf(codes.Internal, "found SPIRE key without arn: %q", keyArn) + case !describeResp.KeyMetadata.Enabled: + // Key is disabled or pending deletion. This can happen when a key + // was scheduled for deletion by the alias-based disposal path, + // which does not clear SPIRE tags. Skip the key gracefully. + kf.log.Warn("Skipping disabled SPIRE key found via tags", keyArnTag, keyArn) + return nil, nil + } + + keyType, ok := keyTypeFromKeySpec(describeResp.KeyMetadata.KeySpec) + if !ok { + return nil, status.Errorf(codes.Internal, "unsupported key spec: %v", describeResp.KeyMetadata.KeySpec) + } + + publicKeyResp, err := kf.kmsClient.GetPublicKey(ctx, &kms.GetPublicKeyInput{KeyId: &keyArn}) + switch { + case err != nil: + return nil, status.Errorf(codes.Internal, "failed to get public key: %v", err) + case publicKeyResp == nil || publicKeyResp.PublicKey == nil || len(publicKeyResp.PublicKey) == 0: + return nil, status.Error(codes.Internal, "malformed get public key response") + } + + // Build the expected alias name for this key. Even though we're using tag-based discovery, + // aliases are still created for all keys (for human-readable names in AWS console). + trustDomain := sanitizeTrustDomain(kf.trustDomain) + aliasName := path.Join(aliasPrefix, trustDomain, kf.serverID, encodeKeyID(spireKeyID)) + + return &keyEntry{ + Arn: *describeResp.KeyMetadata.Arn, + AliasName: aliasName, + PublicKey: &keymanagerv1.PublicKey{ + Id: spireKeyID, + Type: keyType, + PkixData: publicKeyResp.PublicKey, + Fingerprint: makeFingerprint(publicKeyResp.PublicKey), + }, + }, nil +} + +// fetchKeyEntriesWithMigration performs tag-based discovery with automatic +// migration of pre-existing untagged keys. It fetches keys via both tags and +// aliases, then applies SPIRE tags to any alias-discovered keys that don't +// have them yet. This allows a transparent one-time migration from alias-based +// to tag-based discovery without any manual steps. lastUpdate is the +// spire-last-update tag value stamped on migrated keys so they are immediately +// eligible for staleness evaluation. +func (kf *keyFetcher) fetchKeyEntriesWithMigration(ctx context.Context, spireTags []types.Tag, lastUpdate string) ([]*keyEntry, error) { + taggedEntries, err := kf.fetchKeyEntriesViaTag(ctx) + if err != nil { + return nil, err + } + + // Also run alias-based discovery to catch pre-existing keys that were + // created before tag-based discovery was enabled. + aliasEntries, err := kf.fetchKeyEntriesViaAlias(ctx) + if err != nil { + return nil, err + } + + taggedIDs := make(map[string]bool, len(taggedEntries)) + for _, e := range taggedEntries { + taggedIDs[e.PublicKey.Id] = true + } + + var migratedCount int + for _, entry := range aliasEntries { + if taggedIDs[entry.PublicKey.Id] { + continue + } + + kf.log.Info("Applying SPIRE tags to legacy key during migration to tag-based discovery", + keyArnTag, entry.Arn) + + // Build a fresh slice to avoid mutating the shared spireTags backing array. + tags := append(append([]types.Tag(nil), spireTags...), + types.Tag{ + TagKey: aws.String(tagKeySPIREKeyID), + TagValue: aws.String(entry.PublicKey.Id), + }, + types.Tag{ + TagKey: aws.String(tagKeyLastUpdate), + TagValue: aws.String(lastUpdate), + }, + ) + if _, err := kf.kmsClient.TagResource(ctx, &kms.TagResourceInput{ + KeyId: &entry.Arn, + Tags: tags, + }); err != nil { + // Don't fail startup. The key is still usable and tagging will + // be retried on the next server restart. + kf.log.Warn("Failed to apply SPIRE tags to legacy key during migration; key will still be available", + keyArnTag, entry.Arn, reasonTag, err) + } + + migratedCount++ + taggedEntries = append(taggedEntries, entry) + } + + if migratedCount > 0 { + kf.log.Info("Tag-based key discovery migration finished", + "migrated_keys", migratedCount, "total_keys", len(taggedEntries)) + } else { + kf.log.Debug("No legacy keys required migration to tag-based discovery", + "total_keys", len(taggedEntries)) + } + + return taggedEntries, nil +} + +// spireKeyIDFromTags extracts the SPIRE key ID from a resource's tags. +func (kf *keyFetcher) spireKeyIDFromTags(tags []rgtatypes.Tag) (string, bool) { + for _, tag := range tags { + if tag.Key != nil && *tag.Key == tagKeySPIREKeyID && tag.Value != nil { + return *tag.Value, true + } + } + return "", false +} + +// tagGetResourcesPermissionError returns an actionable error when err indicates +// the identity is not authorized to call tag:GetResources. Tag-based key +// discovery relies on the Resource Groups Tagging API, whose permissions are +// identity-based and cannot be granted through the KMS key policy. Returns nil +// when err is not an access-denied error. +func tagGetResourcesPermissionError(err error) error { + var apiErr smithy.APIError + if errors.As(err, &apiErr) && apiErr.ErrorCode() == "AccessDeniedException" { + return status.Errorf(codes.FailedPrecondition, + "tag-based key discovery requires the \"tag:GetResources\" permission in an "+ + "identity-based IAM policy (it cannot be granted through the KMS key policy); "+ + "grant the permission or disable enable_tag_based_key_discovery: %v", err) + } + return nil +} From 63c9901728e2cc61cdf299a66bf3bd65654dc929 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Mart=C3=ADnez=20Fay=C3=B3?= Date: Sun, 31 May 2026 18:20:55 -0300 Subject: [PATCH 2/2] Address Copilot comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Agustín Martínez Fayó --- pkg/server/plugin/keymanager/awskms/awskms.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/server/plugin/keymanager/awskms/awskms.go b/pkg/server/plugin/keymanager/awskms/awskms.go index 7d92949e13..3c34b2ca0b 100644 --- a/pkg/server/plugin/keymanager/awskms/awskms.go +++ b/pkg/server/plugin/keymanager/awskms/awskms.go @@ -140,8 +140,8 @@ type Config struct { // In a future SPIRE version, this will default to true. The alias-based // approach will be deprecated in a future version and removed in a later one. // - // Note: When enabled, the plugin requires permission to use the - // resourcegroupstaggingapi:GetResources API action. + // Note: When enabled, the plugin requires the tag:GetResources IAM + // permission (from the AWS Resource Groups Tagging API). EnableTagBasedKeyDiscovery bool `hcl:"enable_tag_based_key_discovery" json:"enable_tag_based_key_discovery"` } @@ -1206,8 +1206,9 @@ func buildKeyTags(tags map[string]string) []types.Tag { // at creation time. These tags enable efficient key discovery via the AWS // Resource Groups Tagging API. // -// Note: spire-last-update is intentionally omitted here. It is set exclusively -// by keepActiveKeys, which runs on a regular schedule. +// Note: spire-last-update is intentionally omitted here. It is stamped +// separately at key creation and during migration (with the current +// timestamp) and refreshed on a regular schedule by keepActiveKeys. func (p *Plugin) buildSPIRETags(serverID, trustDomain string) []types.Tag { return []types.Tag{ {