diff --git a/src/agent/agent.go b/src/agent/agent.go index f688d3c..98a0463 100644 --- a/src/agent/agent.go +++ b/src/agent/agent.go @@ -8,6 +8,7 @@ import ( "cattery/lib/messages" "context" "errors" + "fmt" "os" "os/signal" "path" @@ -16,13 +17,15 @@ import ( "github.com/fsnotify/fsnotify" log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" ) var RunnerFolder string var CatteryServerUrl string var Id string -// shutdownCause is used as context.Cause to carry the termination reason. +// shutdownCause is returned by watchers to report why shutdown was requested. +// Satisfies error so it can propagate through errgroup. type shutdownCause struct { reason messages.UnregisterReason message string @@ -56,7 +59,9 @@ func NewCatteryAgent(runnerFolder string, catteryServerUrl string, agentId strin func (a *CatteryAgent) Start() { a.logger.Info("Starting Cattery Agent") - agent, jitConfig, err := a.catteryClient.RegisterAgent(a.agentId) + registerCtx, cancelRegister := context.WithTimeout(context.Background(), 30*time.Second) + agent, jitConfig, err := a.catteryClient.RegisterAgent(registerCtx, a.agentId) + cancelRegister() if err != nil { a.logger.Errorf("Failed to register agent: %v", err) return @@ -65,55 +70,59 @@ func (a *CatteryAgent) Start() { a.logger.Info("Agent registered, starting Listener") - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) - - a.watchSignal(ctx, cancel) - a.watchFile(ctx, cancel) - a.watchPing(ctx, cancel) - - var ghListener = githubListener.NewGithubListener(a.listenerExecPath) - ghListener.Start(ctx, cancel, jitConfig) - - // Block until any source triggers cancellation - <-ctx.Done() + // File watcher setup is synchronous so startup fails fast and doesn't + // leak a goroutine or half-initialized fsnotify state. + watcher, err := a.setupFileWatcher() + if err != nil { + a.logger.Errorf("Failed to start file watcher: %v", err) + a.unregisterAndShutdown(messages.UnregisterReasonDone, "file watcher setup: "+err.Error()) + return + } + defer watcher.Close() + + g, ctx := errgroup.WithContext(context.Background()) + + g.Go(func() error { return a.watchSignal(ctx) }) + g.Go(func() error { return a.watchFile(ctx, watcher) }) + g.Go(func() error { return a.watchPing(ctx) }) + g.Go(func() error { + listener := githubListener.NewGithubListener(a.listenerExecPath) + err := listener.Run(ctx, jitConfig) + // Listener exit (clean or otherwise) must cancel the group. Translate + // into a shutdownCause so Wait() returns an error and errgroup signals + // the other watchers. + if err == nil { + return &shutdownCause{reason: messages.UnregisterReasonDone, message: "Listener finished"} + } + return &shutdownCause{reason: messages.UnregisterReasonDone, message: "Listener exited: " + err.Error()} + }) - // Determine what happened - reason, msg := a.resolveShutdownCause(ctx) + reason, msg := a.resolveShutdownCause(g.Wait()) a.logger.Infof("Shutdown: reason=%d, message=%s", reason, msg) - // Kill listener if it wasn't the one that finished - if reason != messages.UnregisterReasonDone { - ghListener.Stop() - } - a.unregisterAndShutdown(reason, msg) } -// resolveShutdownCause extracts the termination reason from the context cause. -// - shutdownCause: a watcher triggered shutdown (signal, file, ping) -// - nil cause: listener exited cleanly -// - other error: listener exited with error -func (a *CatteryAgent) resolveShutdownCause(ctx context.Context) (messages.UnregisterReason, string) { - cause := context.Cause(ctx) - +// resolveShutdownCause unwraps the first error returned by the errgroup. +// Watchers wrap their shutdown reasons in *shutdownCause; anything else is +// treated as a listener-finished signal. +func (a *CatteryAgent) resolveShutdownCause(err error) (messages.UnregisterReason, string) { var sc *shutdownCause - if errors.As(cause, &sc) { + if errors.As(err, &sc) { return sc.reason, sc.message } - - // Listener finished (cancel was called with nil or a process error) - if cause == nil { + if err == nil { return messages.UnregisterReasonDone, "Listener finished" } - return messages.UnregisterReasonDone, "Listener exited: " + cause.Error() + return messages.UnregisterReasonDone, err.Error() } func (a *CatteryAgent) unregisterAndShutdown(reason messages.UnregisterReason, msg string) { log.Infof("Stopping Cattery Agent with reason: %d, message: `%s`", reason, msg) - err := a.catteryClient.UnregisterAgent(a.agent, reason, msg) - if err != nil { + unregisterCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := a.catteryClient.UnregisterAgent(unregisterCtx, a.agent, reason, msg); err != nil { a.logger.Errorf("Failed to unregister agent: %v", err) } @@ -123,93 +132,91 @@ func (a *CatteryAgent) unregisterAndShutdown(reason messages.UnregisterReason, m } } -func (a *CatteryAgent) watchSignal(ctx context.Context, cancel context.CancelCauseFunc) { - go func() { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - - select { - case <-ctx.Done(): - return - case sig := <-sigs: - a.logger.Info("Got signal ", sig) - cancel(&shutdownCause{ - reason: messages.UnregisterReasonSigTerm, - message: "Got signal " + sig.String(), - }) +func (a *CatteryAgent) watchSignal(ctx context.Context) error { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(sigs) + + select { + case <-ctx.Done(): + return ctx.Err() + case sig := <-sigs: + a.logger.Info("Got signal ", sig) + return &shutdownCause{ + reason: messages.UnregisterReasonSigTerm, + message: "Got signal " + sig.String(), } - }() + } } -func (a *CatteryAgent) watchFile(ctx context.Context, cancel context.CancelCauseFunc) { +func (a *CatteryAgent) setupFileWatcher() (*fsnotify.Watcher, error) { const shutdownFile = "./shutdown_file" - go func() { - watcher, err := fsnotify.NewWatcher() - if err != nil { - a.logger.Fatalf("Failed to create file watcher: %v", err) - } - defer watcher.Close() + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("create file watcher: %w", err) + } - // Create the shutdown file if it doesn't exist - f, err := os.OpenFile(shutdownFile, os.O_CREATE|os.O_APPEND, 0644) - if err != nil { - a.logger.Fatalf("Failed to create shutdown file: %v", err) - } - f.Close() + f, err := os.OpenFile(shutdownFile, os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + watcher.Close() + return nil, fmt.Errorf("create shutdown file: %w", err) + } + f.Close() - if err := watcher.Add(shutdownFile); err != nil { - a.logger.Fatalf("Failed to watch shutdown file: %v", err) - } + if err := watcher.Add(shutdownFile); err != nil { + watcher.Close() + return nil, fmt.Errorf("watch shutdown file: %w", err) + } - select { - case <-ctx.Done(): - return - case event := <-watcher.Events: - msg := "Shutdown file changed: " + event.Name - a.logger.Info(msg) - cancel(&shutdownCause{ - reason: messages.UnregisterReasonPreempted, - message: msg, - }) - case watchErr := <-watcher.Errors: - msg := "File watcher error: " + watchErr.Error() - a.logger.Error(msg) - cancel(&shutdownCause{ - reason: messages.UnregisterReasonPreempted, - message: msg, - }) - } - }() + return watcher, nil } -func (a *CatteryAgent) watchPing(ctx context.Context, cancel context.CancelCauseFunc) { - go func() { - for { - select { - case <-ctx.Done(): - return - default: - } +func (a *CatteryAgent) watchFile(ctx context.Context, watcher *fsnotify.Watcher) error { + select { + case <-ctx.Done(): + return ctx.Err() + case event := <-watcher.Events: + msg := "Shutdown file changed: " + event.Name + a.logger.Info(msg) + return &shutdownCause{reason: messages.UnregisterReasonPreempted, message: msg} + case watchErr := <-watcher.Errors: + msg := "File watcher error: " + watchErr.Error() + a.logger.Error(msg) + return &shutdownCause{reason: messages.UnregisterReasonPreempted, message: msg} + } +} - pingResponse, err := a.catteryClient.Ping() - if err != nil { - a.logger.Errorf("Error pinging controller: %v", err) - time.Sleep(60 * time.Second) - continue - } +func (a *CatteryAgent) watchPing(ctx context.Context) error { + const pingInterval = 60 * time.Second + const pingTimeout = 15 * time.Second + + for { + pingCtx, cancel := context.WithTimeout(ctx, pingTimeout) + pingResponse, err := a.catteryClient.Ping(pingCtx) + cancel() + + // If ctx was cancelled while Ping was in flight, exit without logging + // a spurious transport error. + if ctx.Err() != nil { + return ctx.Err() + } - if pingResponse.Terminate { - msg := "Controller requested termination: " + pingResponse.Message - a.logger.Info(msg) - cancel(&shutdownCause{ - reason: messages.UnregisterReasonControllerKill, - message: msg, - }) - return + if err != nil { + a.logger.Errorf("Error pinging controller: %v", err) + } else if pingResponse.Terminate { + msg := "Controller requested termination: " + pingResponse.Message + a.logger.Info(msg) + return &shutdownCause{ + reason: messages.UnregisterReasonControllerKill, + message: msg, } + } - time.Sleep(60 * time.Second) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(pingInterval): } - }() + } } diff --git a/src/agent/catteryClient/client.go b/src/agent/catteryClient/client.go index 2acf8f2..5d554e4 100644 --- a/src/agent/catteryClient/client.go +++ b/src/agent/catteryClient/client.go @@ -4,15 +4,21 @@ import ( "bytes" "cattery/lib/agents" "cattery/lib/messages" + "context" "encoding/json" "fmt" "io" "net/http" "net/url" + "time" "github.com/sirupsen/logrus" ) +// Per-request timeout applied when the caller supplies a context without a +// deadline. Keeps a dead or unreachable server from wedging the agent. +const defaultRequestTimeout = 30 * time.Second + type CatteryClient struct { httpClient *http.Client baseURL string @@ -22,56 +28,49 @@ type CatteryClient struct { func NewCatteryClient(baseURL string, agentId string) *CatteryClient { return &CatteryClient{ - httpClient: &http.Client{}, + httpClient: &http.Client{Timeout: defaultRequestTimeout}, baseURL: baseURL, logger: logrus.WithField("name", "catteryClient"), agentId: agentId, } } -// RegisterAgent request just-in-time runner configuration from the Cattery server -// and returns the configuration as a base64 encoded string +// RegisterAgent requests just-in-time runner configuration from the Cattery +// server and returns the agent plus its JIT config blob. // // https://docs.github.com/en/rest/actions/self-hosted-runners?apiVersion=2022-11-28#create-configuration-for-a-just-in-time-runner-for-an-organization -func (c *CatteryClient) RegisterAgent(id string) (*agents.Agent, *string, error) { - - client := c.httpClient - +func (c *CatteryClient) RegisterAgent(ctx context.Context, id string) (*agents.Agent, *string, error) { requestUrl, err := url.JoinPath(c.baseURL, "/agent", "register/", id) if err != nil { return nil, nil, err } - request, err := http.NewRequest("GET", requestUrl, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestUrl, nil) if err != nil { - return nil, nil, fmt.Errorf("failed to create request: %w", err) + return nil, nil, fmt.Errorf("create register request: %w", err) } - response, err := client.Do(request) + + response, err := c.httpClient.Do(req) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("register request: %w", err) } - defer response.Body.Close() if response.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(response.Body) - return nil, nil, fmt.Errorf("response status code: %s body: %s", response.Status, string(bodyBytes)) + return nil, nil, fmt.Errorf("register response status %s body: %s", response.Status, string(bodyBytes)) } registerResponse := &messages.RegisterResponse{} - err = json.NewDecoder(response.Body).Decode(registerResponse) - if err != nil { - return nil, nil, err + if err := json.NewDecoder(response.Body).Decode(registerResponse); err != nil { + return nil, nil, fmt.Errorf("decode register response: %w", err) } return ®isterResponse.Agent, ®isterResponse.JitConfig, nil } -// UnregisterAgent sends a POST request to the Cattery server to unregister the agent -func (c *CatteryClient) UnregisterAgent(agent *agents.Agent, reason messages.UnregisterReason, message string) error { - - client := c.httpClient - +// UnregisterAgent tells the server to unregister this agent. +func (c *CatteryClient) UnregisterAgent(ctx context.Context, agent *agents.Agent, reason messages.UnregisterReason, message string) error { requestJson, err := json.Marshal(messages.UnregisterRequest{ Agent: *agent, Reason: reason, @@ -86,52 +85,50 @@ func (c *CatteryClient) UnregisterAgent(agent *agents.Agent, reason messages.Unr return err } - request, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(requestJson)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestUrl, bytes.NewBuffer(requestJson)) if err != nil { - return fmt.Errorf("failed to create request: %w", err) + return fmt.Errorf("create unregister request: %w", err) } - response, err := client.Do(request) + + response, err := c.httpClient.Do(req) if err != nil { - return err + return fmt.Errorf("unregister request: %w", err) } - defer response.Body.Close() if response.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(response.Body) - return fmt.Errorf("response status code: %s body: %s", response.Status, string(bodyBytes)) + return fmt.Errorf("unregister response status %s body: %s", response.Status, string(bodyBytes)) } return nil } -func (c *CatteryClient) Ping() (*messages.PingResponse, error) { - +func (c *CatteryClient) Ping(ctx context.Context) (*messages.PingResponse, error) { requestUrl, err := url.JoinPath(c.baseURL, "/agent", "ping", c.agentId) if err != nil { - return nil, fmt.Errorf("failed to join path: %w", err) + return nil, fmt.Errorf("join path: %w", err) } - request, err := http.NewRequest("POST", requestUrl, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestUrl, nil) if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + return nil, fmt.Errorf("create ping request: %w", err) } - response, err := c.httpClient.Do(request) + + response, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("post error: %w", err) + return nil, fmt.Errorf("ping request: %w", err) } - defer response.Body.Close() if response.StatusCode != http.StatusOK { bodyBytes, _ := io.ReadAll(response.Body) - return nil, fmt.Errorf("response status code: %s body: %s", response.Status, string(bodyBytes)) + return nil, fmt.Errorf("ping response status %s body: %s", response.Status, string(bodyBytes)) } pingResponse := &messages.PingResponse{} - err = json.NewDecoder(response.Body).Decode(pingResponse) - if err != nil { - return nil, fmt.Errorf("error decoding ping response: %w", err) + if err := json.NewDecoder(response.Body).Decode(pingResponse); err != nil { + return nil, fmt.Errorf("decode ping response: %w", err) } return pingResponse, nil diff --git a/src/agent/githubListener/githubListener.go b/src/agent/githubListener/githubListener.go index 5090edb..36b2937 100644 --- a/src/agent/githubListener/githubListener.go +++ b/src/agent/githubListener/githubListener.go @@ -2,72 +2,65 @@ package githubListener import ( "context" + "fmt" "os" "os/exec" - "sync" + "time" log "github.com/sirupsen/logrus" ) +// gracePeriod is how long the runner is given to exit after an interrupt +// before we escalate to SIGKILL. +const gracePeriod = 10 * time.Second + type GithubListener struct { listenerPath string - process *os.Process - started chan struct{} // closed once process has started (or failed) - - mut sync.Mutex } func NewGithubListener(listenerPath string) *GithubListener { - return &GithubListener{ - listenerPath: listenerPath, - started: make(chan struct{}), - } + return &GithubListener{listenerPath: listenerPath} } -// Start launches the GitHub runner listener in a background goroutine. -// When the process exits, it cancels ctx with the resulting error (nil on success). -func (l *GithubListener) Start(ctx context.Context, cancel context.CancelCauseFunc, jitConfig *string) { - var commandRun = exec.Command(l.listenerPath, "run", "--jitconfig", *jitConfig) - commandRun.Stdout = os.Stdout - commandRun.Stderr = os.Stderr +// Run starts the GitHub runner listener and blocks until either the process +// exits or ctx is cancelled. On cancellation the process is interrupted and, +// if it doesn't exit within gracePeriod, forcefully killed with SIGKILL. +func (l *GithubListener) Run(ctx context.Context, jitConfig *string) error { + cmd := exec.Command(l.listenerPath, "run", "--jitconfig", *jitConfig) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr - go func() { - err := commandRun.Start() - if err != nil { - log.Errorf("Listener failed to start: %v", err) - close(l.started) - cancel(err) - return - } + if err := cmd.Start(); err != nil { + return fmt.Errorf("start listener: %w", err) + } - l.mut.Lock() - l.process = commandRun.Process - l.mut.Unlock() - close(l.started) + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() - err = commandRun.Wait() - cancel(err) // nil means clean exit - }() + select { + case err := <-done: + return err + case <-ctx.Done(): + shutdownProcess(cmd.Process, done) + return ctx.Err() + } } -func (l *GithubListener) Stop() { - <-l.started // wait for process to be set before attempting kill - - l.mut.Lock() - defer l.mut.Unlock() - - if l.process == nil { - return +// shutdownProcess asks the process to exit, escalating to SIGKILL if the +// grace period elapses. Blocks until cmd.Wait() observes the exit. +func shutdownProcess(p *os.Process, done <-chan error) { + if err := interrupt(p); err != nil { + log.Errorf("Failed to interrupt listener: %v", err) } - err := l.kill() - if err != nil { - log.Error("Failed to kill process: ", err) + select { + case <-done: + return + case <-time.After(gracePeriod): + log.Warnf("Listener did not exit within %s, sending SIGKILL", gracePeriod) + if err := p.Kill(); err != nil { + log.Errorf("Failed to kill listener: %v", err) + } + <-done } - - l.process = nil -} - -func (l *GithubListener) kill() error { - return kill(l) } diff --git a/src/agent/githubListener/kill.go b/src/agent/githubListener/kill.go index 6543fa7..592d965 100644 --- a/src/agent/githubListener/kill.go +++ b/src/agent/githubListener/kill.go @@ -7,11 +7,12 @@ import ( "os" ) -func kill(l *GithubListener) error { - err := l.process.Signal(os.Kill) - if err != nil { - return fmt.Errorf("failed to kill process: %w", err) +// interrupt asks the process to exit gracefully. On non-linux platforms we +// don't have a separate graceful signal in use here, so we fall through to +// os.Kill — the caller's grace-period + SIGKILL escalation still applies. +func interrupt(process *os.Process) error { + if err := process.Signal(os.Interrupt); err != nil { + return fmt.Errorf("signal process: %w", err) } - return nil } diff --git a/src/agent/githubListener/kill_linux.go b/src/agent/githubListener/kill_linux.go index 3c2743a..4f84d9e 100644 --- a/src/agent/githubListener/kill_linux.go +++ b/src/agent/githubListener/kill_linux.go @@ -2,22 +2,15 @@ package githubListener import ( "fmt" + "os" "os/exec" ) -func kill(l *GithubListener) error { - var commandInterruptRun = exec.Command("pkill", "--signal", "SIGINT", "Runner.Listener") - err := commandInterruptRun.Run() - if err != nil { - return fmt.Errorf("failed to interrupt runner: %w", err) +// interrupt asks the runner to exit gracefully. Direct SIGINT to the process +// doesn't behave as expected (TODO: investigate), so we pkill by name. +func interrupt(_ *os.Process) error { + if err := exec.Command("pkill", "--signal", "SIGINT", "Runner.Listener").Run(); err != nil { + return fmt.Errorf("pkill Runner.Listener: %w", err) } - return nil - - // TODO: debug why SIGINT does not work correctly - // err := runnerProcess.Signal(syscall.SIGINT) - // if err != nil { - // var errMsg = "Failed to interrupt runner: " + err.Error() - // a.logger.Error(errMsg) - // } } diff --git a/src/lib/trayManager/trayManager.go b/src/lib/trayManager/trayManager.go index 43a36d9..c7452c4 100644 --- a/src/lib/trayManager/trayManager.go +++ b/src/lib/trayManager/trayManager.go @@ -92,7 +92,7 @@ func (tm *TrayManager) CreateTray(ctx context.Context, trayType *config.TrayType return err } - err = provider.RunTray(tray) + err = provider.RunTray(ctx, tray) if err != nil { log.Errorf("Failed to run tray for provider '%s', tray '%s': %v", trayType.Provider, tray.Id, err) metrics.TrayProviderErrors(tray.GitHubOrgName, tray.ProviderName, tray.TrayTypeName, "create") @@ -102,7 +102,7 @@ func (tm *TrayManager) CreateTray(ctx context.Context, trayType *config.TrayType err = tm.trayRepository.Save(ctx, tray) if err != nil { log.Errorf("Failed to save tray %s: %v — cleaning up provider resource", trayType.Name, err) - if cleanErr := provider.CleanTray(tray); cleanErr != nil { + if cleanErr := provider.CleanTray(ctx, tray); cleanErr != nil { log.Errorf("Failed to clean up tray %s after save failure: %v", tray.Id, cleanErr) metrics.TrayProviderErrors(tray.GitHubOrgName, tray.ProviderName, tray.TrayTypeName, "delete") } @@ -168,7 +168,7 @@ func (tm *TrayManager) DeleteTray(ctx context.Context, trayId string) (*trays.Tr return nil, err } - err = provider.CleanTray(tray) + err = provider.CleanTray(ctx, tray) if err != nil { log.Errorf("Failed to delete tray for provider %s, tray %s: %v", provider.GetProviderName(), tray.Id, err) metrics.TrayProviderErrors(tray.GitHubOrgName, tray.ProviderName, tray.TrayTypeName, "delete") diff --git a/src/lib/trayManager/trayManager_test.go b/src/lib/trayManager/trayManager_test.go index 169a8d3..b5005d4 100644 --- a/src/lib/trayManager/trayManager_test.go +++ b/src/lib/trayManager/trayManager_test.go @@ -25,13 +25,13 @@ type mockProvider struct { } func (m *mockProvider) GetProviderName() string { return m.name } -func (m *mockProvider) RunTray(_ *trays.Tray) error { +func (m *mockProvider) RunTray(_ context.Context, _ *trays.Tray) error { m.mu.Lock() defer m.mu.Unlock() m.runCalls++ return m.runErr } -func (m *mockProvider) CleanTray(tray *trays.Tray) error { +func (m *mockProvider) CleanTray(_ context.Context, tray *trays.Tray) error { m.mu.Lock() defer m.mu.Unlock() m.cleaned = append(m.cleaned, tray.Id) diff --git a/src/lib/trays/providers/dockerProvider.go b/src/lib/trays/providers/dockerProvider.go index 6356139..84f25c9 100644 --- a/src/lib/trays/providers/dockerProvider.go +++ b/src/lib/trays/providers/dockerProvider.go @@ -3,6 +3,7 @@ package providers import ( "cattery/lib/config" "cattery/lib/trays" + "context" "fmt" "os/exec" "strings" @@ -33,7 +34,7 @@ func (d *DockerProvider) GetProviderName() string { return d.name } -func (d *DockerProvider) RunTray(tray *trays.Tray) error { +func (d *DockerProvider) RunTray(ctx context.Context, tray *trays.Tray) error { containerName := tray.Id @@ -45,7 +46,7 @@ func (d *DockerProvider) RunTray(tray *trays.Tray) error { image := trayConfig.Image serverUrl := config.Get().Server.AdvertiseUrl - dockerCommand := exec.Command("docker", "run", "-d", "--rm", + dockerCommand := exec.CommandContext(ctx, "docker", "run", "-d", "--rm", "--add-host=host.docker.internal:host-gateway", "--name", containerName, image, @@ -62,8 +63,8 @@ func (d *DockerProvider) RunTray(tray *trays.Tray) error { return nil } -func (d *DockerProvider) CleanTray(tray *trays.Tray) error { - dockerCommand := exec.Command("docker", "container", "stop", tray.Id) +func (d *DockerProvider) CleanTray(ctx context.Context, tray *trays.Tray) error { + dockerCommand := exec.CommandContext(ctx, "docker", "container", "stop", tray.Id) dockerCommandOutput, err := dockerCommand.CombinedOutput() if err != nil { output := string(dockerCommandOutput) diff --git a/src/lib/trays/providers/gceProvider.go b/src/lib/trays/providers/gceProvider.go index 97601b4..2f5e65b 100644 --- a/src/lib/trays/providers/gceProvider.go +++ b/src/lib/trays/providers/gceProvider.go @@ -24,7 +24,7 @@ type GceProvider struct { logger *logrus.Entry } -func NewGceProvider(name string, providerConfig config.ProviderConfig) *GceProvider { +func NewGceProvider(name string, providerConfig config.ProviderConfig) (*GceProvider, error) { provider := &GceProvider{ Name: name, providerConfig: providerConfig, @@ -33,11 +33,11 @@ func NewGceProvider(name string, providerConfig config.ProviderConfig) *GceProvi client, err := provider.createInstancesClient() if err != nil { - return nil + return nil, fmt.Errorf("create GCE instances client: %w", err) } provider.instanceClient = client - return provider + return provider, nil } func (g *GceProvider) GetProviderName() string { @@ -51,9 +51,7 @@ func (g *GceProvider) Close() error { return nil } -func (g *GceProvider) RunTray(tray *trays.Tray) error { - ctx := context.Background() - +func (g *GceProvider) RunTray(ctx context.Context, tray *trays.Tray) error { trayConfig, ok := tray.TrayConfig().(config.GoogleTrayConfig) if !ok { return fmt.Errorf("unexpected tray config type for gce provider, tray %s", tray.Id) @@ -108,16 +106,11 @@ func (g *GceProvider) RunTray(tray *trays.Tray) error { return nil } -func (g *GceProvider) CleanTray(tray *trays.Tray) error { - client, err := g.createInstancesClient() - if err != nil { - return err - } - +func (g *GceProvider) CleanTray(ctx context.Context, tray *trays.Tray) error { zone := tray.ProviderData["zone"] project := g.providerConfig.Get("project") - _, err = client.Delete(context.Background(), &computepb.DeleteInstanceRequest{ + _, err := g.instanceClient.Delete(ctx, &computepb.DeleteInstanceRequest{ Instance: tray.Id, Project: project, Zone: zone, @@ -139,30 +132,12 @@ func (g *GceProvider) CleanTray(tray *trays.Tray) error { } func (g *GceProvider) createInstancesClient() (*compute.InstancesClient, error) { - - if g.instanceClient != nil { - return g.instanceClient, nil - } - ctx := context.Background() - var ( - instancesClient *compute.InstancesClient - err error - ) - if credFile := g.providerConfig.Get("credentialsFile"); credFile != "" { - instancesClient, err = compute.NewInstancesRESTClient(ctx, option.WithCredentialsFile(g.providerConfig.Get("credentialsFile"))) - } else { - instancesClient, err = compute.NewInstancesRESTClient(ctx) + return compute.NewInstancesRESTClient(ctx, option.WithCredentialsFile(credFile)) } - - if err != nil { - return nil, err - } - - g.instanceClient = instancesClient - return instancesClient, nil + return compute.NewInstancesRESTClient(ctx) } func createGcpMetadata(fieldMaps ...map[string]string) *computepb.Metadata { diff --git a/src/lib/trays/providers/trayProvider.go b/src/lib/trays/providers/trayProvider.go index 1c540de..85133f9 100644 --- a/src/lib/trays/providers/trayProvider.go +++ b/src/lib/trays/providers/trayProvider.go @@ -2,16 +2,17 @@ package providers import ( "cattery/lib/trays" + "context" ) type TrayProvider interface { GetProviderName() string // RunTray spawns a new tray. - RunTray(tray *trays.Tray) error + RunTray(ctx context.Context, tray *trays.Tray) error // CleanTray deletes the tray with the given ID. - CleanTray(tray *trays.Tray) error + CleanTray(ctx context.Context, tray *trays.Tray) error } // TrayProviderFactory resolves providers by name or by tray. diff --git a/src/lib/trays/providers/trayProviderFactory.go b/src/lib/trays/providers/trayProviderFactory.go index 7e4cfa0..7449a4c 100644 --- a/src/lib/trays/providers/trayProviderFactory.go +++ b/src/lib/trays/providers/trayProviderFactory.go @@ -65,9 +65,11 @@ func GetProvider(providerName string) (TrayProvider, error) { result = p } case "google": - if p := NewGceProvider(providerName, provider); p != nil { - result = p + p, err := NewGceProvider(providerName, provider) + if err != nil { + return nil, err } + result = p default: return nil, errors.New("unknown provider type: " + provider["type"]) } diff --git a/src/server/handlers/agentHandler.go b/src/server/handlers/agentHandler.go index 7291b43..cb54acb 100644 --- a/src/server/handlers/agentHandler.go +++ b/src/server/handlers/agentHandler.go @@ -270,4 +270,6 @@ func (h *Handlers) AgentInterrupt(responseWriter http.ResponseWriter, r *http.Re http.Error(responseWriter, "Failed to request restart", http.StatusInternalServerError) return } + + responseWriter.WriteHeader(http.StatusOK) } diff --git a/src/server/handlers/agentHandler_test.go b/src/server/handlers/agentHandler_test.go index 8145396..c7f90f1 100644 --- a/src/server/handlers/agentHandler_test.go +++ b/src/server/handlers/agentHandler_test.go @@ -26,9 +26,9 @@ import ( type mockProvider struct{} -func (m *mockProvider) GetProviderName() string { return "mock" } -func (m *mockProvider) RunTray(_ *trays.Tray) error { return nil } -func (m *mockProvider) CleanTray(_ *trays.Tray) error { return nil } +func (m *mockProvider) GetProviderName() string { return "mock" } +func (m *mockProvider) RunTray(_ context.Context, _ *trays.Tray) error { return nil } +func (m *mockProvider) CleanTray(_ context.Context, _ *trays.Tray) error { return nil } type mockProviderFactory struct{} diff --git a/src/server/handlers/integration_test.go b/src/server/handlers/integration_test.go index 486bca9..c038099 100644 --- a/src/server/handlers/integration_test.go +++ b/src/server/handlers/integration_test.go @@ -52,9 +52,9 @@ var _ scaleSetClient.JitConfigGenerator = (*mockJitConfigGenerator)(nil) // mockProviderFactory implements providers.TrayProviderFactory type integrationMockProvider struct{} -func (m *integrationMockProvider) GetProviderName() string { return "mock" } -func (m *integrationMockProvider) RunTray(_ *trays.Tray) error { return nil } -func (m *integrationMockProvider) CleanTray(_ *trays.Tray) error { return nil } +func (m *integrationMockProvider) GetProviderName() string { return "mock" } +func (m *integrationMockProvider) RunTray(_ context.Context, _ *trays.Tray) error { return nil } +func (m *integrationMockProvider) CleanTray(_ context.Context, _ *trays.Tray) error { return nil } type integrationMockProviderFactory struct{}