diff --git a/managed/cmd/pmm-managed/main.go b/managed/cmd/pmm-managed/main.go index b944c86e67..25fcf074d2 100644 --- a/managed/cmd/pmm-managed/main.go +++ b/managed/cmd/pmm-managed/main.go @@ -1152,9 +1152,7 @@ func main() { //nolint:gocognit,maintidx,cyclop updater.Run(ctx) }) - wg.Add(1) haService.AddLeaderService(ha.NewContextService("telemetry", func(ctx context.Context) error { - defer wg.Done() telemetry.Run(ctx) return nil })) diff --git a/managed/services/backup/removal_service.go b/managed/services/backup/removal_service.go index 0da7df977f..bfebdd175d 100644 --- a/managed/services/backup/removal_service.go +++ b/managed/services/backup/removal_service.go @@ -163,7 +163,13 @@ func (s *RemovalService) TrimPITRArtifact(storage Storage, artifactID string, fi return } - err = s.deleteArtifactPITRChunks(context.Background(), storage, location, artifact, artifact.MetadataList[0].RestoreTo) + // After trimming, MetadataList may be empty (firstN covered all + // records); a nil "until" then deletes all remaining PITR chunks. + var until *time.Time + if len(artifact.MetadataList) > 0 { + until = artifact.MetadataList[0].RestoreTo + } + err = s.deleteArtifactPITRChunks(context.Background(), storage, location, artifact, until) if err != nil { s.l.WithError(err).Error("couldn't delete artifact PITR chunks") return diff --git a/managed/services/backup/removal_service_test.go b/managed/services/backup/removal_service_test.go index 1956f5dd91..17a31fa1f3 100644 --- a/managed/services/backup/removal_service_test.go +++ b/managed/services/backup/removal_service_test.go @@ -371,6 +371,46 @@ func TestTrimPITRArtifact(t *testing.T) { assert.Len(t, artifact.MetadataList, 2) }) + t.Run("trimming all remaining metadata", func(t *testing.T) { + chunksRet := []*oplogChunk{ + {FName: "chunkA"}, + } + + mockedStorage.On("RemoveRecursive", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/dir2/"). 
+ Return(nil).Once() + mockedStorage.On("Remove", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/file4"). + Return(nil).Once() + mockedStorage.On("Remove", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/file5"). + Return(nil).Once() + mockedStorage.On("Remove", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/file6"). + Return(nil).Once() + mockedStorage.On("RemoveRecursive", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/dir3/"). + Return(nil).Once() + mockedStorage.On("Remove", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/file7"). + Return(nil).Once() + mockedStorage.On("Remove", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/file8"). + Return(nil).Once() + mockedStorage.On("Remove", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "artifact_folder/file9"). + Return(nil).Once() + + // All metadata is removed, so there is no remaining restore point and + // every PITR chunk is deleted (until == nil). Trimming must not panic + // indexing an empty MetadataList. + mockedPbmPITRService.On("GetPITRFiles", mock.Anything, mock.Anything, locationRes, mock.Anything, mock.Anything).Return(chunksRet, nil).Once() + mockedStorage.On("Remove", mock.Anything, s3Config.Endpoint, s3Config.AccessKey, s3Config.SecretKey, s3Config.BucketName, "chunkA"). 
+ Return(nil).Once() + + err := removalService.TrimPITRArtifact(mockedStorage, artifact.ID, 2) + require.NoError(t, err) + + time.Sleep(time.Second * 2) + + artifact, err = models.FindArtifactByID(db.Querier, artifact.ID) + require.NoError(t, err) + require.NotNil(t, artifact) + assert.Empty(t, artifact.MetadataList) + }) + mockedStorage.AssertExpectations(t) mockedPbmPITRService.AssertExpectations(t) } diff --git a/managed/services/checks/checks.go b/managed/services/checks/checks.go index c7e66baf05..6d263c9474 100644 --- a/managed/services/checks/checks.go +++ b/managed/services/checks/checks.go @@ -94,6 +94,11 @@ type Service struct { startDelay time.Duration customCheckFile string // For testing + runCtxM sync.Mutex + // runCtx is the service lifecycle context recorded by Run. It bounds + // asynchronous work started via StartChecks so it is cancelled on shutdown. + runCtx context.Context //nolint:containedctx + am sync.Mutex advisors []check.Advisor checks map[string]check.Check // Checks extracted from advisors and stored by name. 
@@ -134,6 +139,7 @@ func New( l: l, startDelay: defaultStartDelay, customCheckFile: os.Getenv(envCheckFile), + runCtx: context.Background(), mChecksExecuted: prom.NewCounterVec(prom.CounterOpts{ Namespace: prometheusNamespace, @@ -171,6 +177,10 @@ func (s *Service) Run(ctx context.Context) { s.l.Info("Starting...") defer s.l.Info("Done.") + s.runCtxM.Lock() + s.runCtx = ctx + s.runCtxM.Unlock() + s.UpdateAdvisorsList(ctx) settings, err := models.GetSettings(s.db) if err != nil { @@ -178,13 +188,14 @@ func (s *Service) Run(ctx context.Context) { return } + s.tm.Lock() s.rareTicker = time.NewTicker(settings.SaaS.AdvisorRunIntervals.RareInterval) - defer s.rareTicker.Stop() - s.standardTicker = time.NewTicker(settings.SaaS.AdvisorRunIntervals.StandardInterval) - defer s.standardTicker.Stop() - s.frequentTicker = time.NewTicker(settings.SaaS.AdvisorRunIntervals.FrequentInterval) + s.tm.Unlock() + + defer s.rareTicker.Stop() + defer s.standardTicker.Stop() defer s.frequentTicker.Stop() // delay for the first run to allow all agents to connect @@ -275,8 +286,11 @@ func (s *Service) StartChecks(checkNames []string) error { return services.ErrAdvisorsDisabled } + s.runCtxM.Lock() + ctx := s.runCtx + s.runCtxM.Unlock() + go func() { - ctx := context.Background() s.UpdateAdvisorsList(ctx) err := s.run(ctx, "", checkNames) if err != nil { @@ -1650,10 +1664,16 @@ func (s *Service) updateAdvisors(advisors []check.Advisor) { // UpdateIntervals updates advisor checks restart timer intervals. func (s *Service) UpdateIntervals(rare, standard, frequent time.Duration) { s.tm.Lock() + defer s.tm.Unlock() + // Tickers are created by Run; if it has not started on this node (e.g. not + // the leader), there is nothing to reset - Run reads the new intervals from + // the persisted settings when it starts. 
+ if s.rareTicker == nil || s.standardTicker == nil || s.frequentTicker == nil { + return + } s.rareTicker.Reset(rare) s.standardTicker.Reset(standard) s.frequentTicker.Reset(frequent) - s.tm.Unlock() s.l.Infof("Intervals are changed: rare %s, standard %s, frequent %s", rare, standard, frequent) } diff --git a/managed/services/checks/checks_test.go b/managed/services/checks/checks_test.go index d51babe792..3bb5a89e2b 100644 --- a/managed/services/checks/checks_test.go +++ b/managed/services/checks/checks_test.go @@ -336,6 +336,24 @@ func TestStartChecks(t *testing.T) { }) } +func TestNewInitializesRunContext(t *testing.T) { + t.Parallel() + // New must initialize runCtx so StartChecks never passes a nil context + // when invoked before Run records the service lifecycle context. + s := New(nil, nil, nil, nil) + require.NotNil(t, s.runCtx) +} + +func TestUpdateIntervalsBeforeRun(t *testing.T) { + t.Parallel() + // UpdateIntervals must not panic when Run has not created the tickers yet + // (e.g. a settings change on a node that is not the leader). + s := New(nil, nil, nil, nil) + assert.NotPanics(t, func() { + s.UpdateIntervals(time.Hour, time.Minute, time.Second) + }) +} + func TestFilterChecks(t *testing.T) { t.Parallel() diff --git a/managed/services/dump/dump.go b/managed/services/dump/dump.go index 4f52b06b9e..c631e66f90 100644 --- a/managed/services/dump/dump.go +++ b/managed/services/dump/dump.go @@ -326,14 +326,6 @@ func (s *Service) saveLogChunk(dumpID string, chunkN uint32, text string, last b return nil } -// StopDump stops the ongoing dump process in the dump service. 
-func (s *Service) StopDump() { - s.rw.RLock() - defer s.rw.RUnlock() - - s.cancel() -} - func getDumpFilePath(id string, encrypted bool) string { s := fmt.Sprintf("%s/%s.tar.gz", dumpsDir, id) if encrypted { diff --git a/managed/services/ha/services.go b/managed/services/ha/services.go index e7dea7f0fd..203904ab08 100644 --- a/managed/services/ha/services.go +++ b/managed/services/ha/services.go @@ -74,13 +74,13 @@ func (s *services) StartAllServices(ctx context.Context) { for id, service := range s.all { if _, ok := s.running[id]; !ok { s.running[id] = service + s.wg.Add(1) toStart = append(toStart, startItem{svc: service, id: id}) } } s.rw.Unlock() for _, service := range toStart { - s.wg.Add(1) go func(svc LeaderService, svcID string) { s.l.Infoln("Starting", svcID) err := svc.Start(ctx) @@ -122,7 +122,12 @@ func (s *services) Wait() { // removeService removes a service from the registry of running services. func (s *services) removeService(id string) { s.rw.Lock() + _, ok := s.running[id] delete(s.running, id) s.rw.Unlock() - s.wg.Done() + // Only decrement for a service we actually removed + // to avoid negative counter and panic. + if ok { + s.wg.Done() + } } diff --git a/managed/services/ha/services_test.go b/managed/services/ha/services_test.go index 007c2f9699..24c88e9d3e 100644 --- a/managed/services/ha/services_test.go +++ b/managed/services/ha/services_test.go @@ -350,6 +350,43 @@ func TestServices_Wait(t *testing.T) { }) } +func TestServices_NoDoubleDoneOnStopThenStartError(t *testing.T) { + t.Parallel() + + s := newServices() + svc := &mockLeaderService{id: "svc"} + require.NoError(t, s.Add(svc)) + + // Simulate StartAllServices having registered and counted the service. + s.rw.Lock() + s.running[svc.id] = svc + s.wg.Add(1) + s.rw.Unlock() + + // Leadership is lost: StopAllServices claims the service and balances the + // WaitGroup. + s.StopAllServices() + + // The service's Start then returns an error, so its goroutine removes it. 
+ // removeService must not decrement the WaitGroup again; otherwise the + // counter goes negative and panics. + assert.NotPanics(t, func() { + s.removeService(svc.id) + }) + + // WaitGroup is balanced: Wait returns promptly. + done := make(chan struct{}) + go func() { + s.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("Wait did not return; WaitGroup accounting is unbalanced") + } +} + type mockLeaderService struct { id string started bool diff --git a/managed/services/server/updater.go b/managed/services/server/updater.go index 28caaff3de..bd1894cd19 100644 --- a/managed/services/server/updater.go +++ b/managed/services/server/updater.go @@ -600,12 +600,12 @@ func (up *Updater) getReleaseNotesText(ctx context.Context, version version.Pars up.l.WithError(err).Errorf("Failed to get release note for version: %s", versionString) return "", errors.Wrapf(err, "failed to get release notes for version: %s", versionString) } + defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { up.l.Errorf("Failed to get release notes for PMM %s, got HTTP %d", version.String(), resp.StatusCode) return "", nil } - defer resp.Body.Close() //nolint:errcheck var rnResponse ReleaseNotesResponse err = json.NewDecoder(resp.Body).Decode(&rnResponse) if err != nil { diff --git a/managed/services/server/updater_test.go b/managed/services/server/updater_test.go index b0f1c6aa17..2c142eb2db 100644 --- a/managed/services/server/updater_test.go +++ b/managed/services/server/updater_test.go @@ -17,6 +17,8 @@ package server import ( "context" + "io" + "net/http" "net/url" "os" "path/filepath" @@ -357,3 +359,43 @@ PMM_IMAGE=docker.io/perconalab/pmm-server:3-dev-container PMM_DISTRIBUTION_METHOD=ami`, string(newContent)) }) } + +// trackingReadCloser records whether Close was called. 
+type trackingReadCloser struct {
+	io.Reader // embedded: supplies Read; Close is implemented below
+
+	closed bool // set to true by Close and never reset
+}
+
+func (rc *trackingReadCloser) Close() error {
+	rc.closed = true // record the call; the test asserts on this flag afterwards
+	return nil
+}
+
+type stubRoundTripper struct {
+	statusCode int           // status code placed on every response RoundTrip returns
+	body       io.ReadCloser // body attached to every response RoundTrip returns
+}
+
+func (rt stubRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { // request is ignored; a canned response is built
+	return &http.Response{
+		StatusCode: rt.statusCode,
+		Body:       rt.body,
+		Header:     make(http.Header), // empty but non-nil header map
+	}, nil
+}
+
+func TestGetReleaseNotesClosesResponseBody(t *testing.T) {
+	t.Setenv(env.PlatformAddress, "https://version.test") // NOTE(review): presumably the base URL getReleaseNotesText queries — confirm
+
+	body := &trackingReadCloser{Reader: strings.NewReader("")}
+	origTransport := http.DefaultClient.Transport // NOTE(review): assumes getReleaseNotesText sends via http.DefaultClient — confirm
+	http.DefaultClient.Transport = stubRoundTripper{statusCode: http.StatusNotFound, body: body}
+	t.Cleanup(func() { http.DefaultClient.Transport = origTransport })
+
+	u := NewUpdater(nil, 0, nil)
+	text, err := u.getReleaseNotesText(t.Context(), *version.MustParse("3.0.0"))
+	require.NoError(t, err) // the stubbed 404 reply must not surface as an error
+	assert.Empty(t, text)
+	assert.True(t, body.closed, "response body must be closed on non-200 status")
+}