Skip to content

Commit

Permalink
fix: send SIGTERM signal to --cmd instead of SIGKILL (#687)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Hesketh <[email protected]>
Co-authored-by: Adrian Hesketh <[email protected]>
  • Loading branch information
3 people authored Sep 2, 2024
1 parent c7c32aa commit 3ac3c9d
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 11 deletions.
108 changes: 108 additions & 0 deletions cmd/templ/generatecmd/run/run_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package run_test

import (
"context"
"embed"
"io"
"net/http"
"os"
"path/filepath"
"syscall"
"testing"
"time"

"github.com/a-h/templ/cmd/templ/generatecmd/run"
)

//go:embed testprogram/*
var testprogram embed.FS

func TestGoRun(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
}

// Copy testprogram to a temporary directory.
dir, err := os.MkdirTemp("", "testprogram")
if err != nil {
t.Fatalf("failed to make test dir: %v", err)
}
files, err := testprogram.ReadDir("testprogram")
if err != nil {
t.Fatalf("failed to read embedded dir: %v", err)
}
for _, file := range files {
srcFileName := "testprogram/" + file.Name()
srcData, err := testprogram.ReadFile(srcFileName)
if err != nil {
t.Fatalf("failed to read src file %q: %v", srcFileName, err)
}
tgtFileName := filepath.Join(dir, file.Name())
tgtFile, err := os.Create(tgtFileName)
if err != nil {
t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err)
}
defer tgtFile.Close()
if _, err := tgtFile.Write(srcData); err != nil {
t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err)
}
}
// Rename the go.mod.embed file to go.mod.
if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil {
t.Fatalf("failed to rename go.mod.embed: %v", err)
}

tests := []struct {
name string
cmd string
}{
{
name: "Well behaved programs get shut down",
cmd: "go run .",
},
{
name: "Badly behaved programs get shut down",
cmd: "go run . -badly-behaved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cmd, err := run.Run(ctx, dir, tt.cmd)
if err != nil {
t.Fatalf("failed to run program: %v", err)
}

time.Sleep(1 * time.Second)

pid := cmd.Process.Pid

if err := run.KillAll(); err != nil {
t.Fatalf("failed to kill all: %v", err)
}

// Check the parent process is no longer running.
if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil {
t.Fatalf("process %d is still running", pid)
}
// Check that the child was stopped.
body, err := readResponse("http://localhost:7777")
if err == nil {
t.Fatalf("child process is still running: %s", body)
}
})
}
}

func readResponse(url string) (body string, err error) {
resp, err := http.Get(url)
if err != nil {
return body, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return body, err
}
return string(b), nil
}
46 changes: 35 additions & 11 deletions cmd/templ/generatecmd/run/run_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,63 @@ package run

import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"sync"
"syscall"
"time"
)

var m = &sync.Mutex{}
var running = map[string]*exec.Cmd{}
var (
m = &sync.Mutex{}
running = map[string]*exec.Cmd{}
)

func KillAll() (err error) {
m.Lock()
defer m.Unlock()
var errs []error
for _, cmd := range running {
err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
if err != nil {
return err
if err := kill(cmd); err != nil {
errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err))
}
}
running = map[string]*exec.Cmd{}
return
return errors.Join(errs...)
}

func kill(cmd *exec.Cmd) (err error) {
errs := make([]error, 4)
errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT))
errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM))
errs[2] = ignoreExited(cmd.Wait())
errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL))
return errors.Join(errs...)
}

func Stop(cmd *exec.Cmd) (err error) {
return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
func ignoreExited(err error) error {
if errors.Is(err, syscall.ESRCH) {
return nil
}
// Ignore *exec.ExitError
if _, ok := err.(*exec.ExitError); ok {
return nil
}
return err
}

func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err error) {
m.Lock()
defer m.Unlock()
cmd, ok := running[input]
if ok {
if err = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL); err != nil {
return cmd, err
if err := kill(cmd); err != nil {
return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)
}

delete(running, input)
}
parts := strings.Fields(input)
Expand All @@ -48,7 +70,9 @@ func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err erro
args = append(args, parts[1:]...)
}

cmd = exec.Command(executable, args...)
cmd = exec.CommandContext(ctx, executable, args...)
// Wait for the process to finish gracefully before termination.
cmd.WaitDelay = time.Second * 3
cmd.Env = os.Environ()
cmd.Dir = workingDir
cmd.Stdout = os.Stdout
Expand Down
3 changes: 3 additions & 0 deletions cmd/templ/generatecmd/run/testprogram/go.mod.embed
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module testprogram

go 1.22.6
63 changes: 63 additions & 0 deletions cmd/templ/generatecmd/run/testprogram/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package main

import (
"flag"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)

// This is a test program. It is used only to test the behaviour of the run package.
// The run package is supposed to be able to run and stop programs. Those programs may start
// child processes, which should also be stopped when the parent program is stopped.

// For example, running `go run .` will compile an executable and run it.

// So, this program does nothing. It just waits for a signal to stop.

// In "Well behaved" mode, the program will stop when it receives a signal.
// In "Badly behaved" mode, the program will ignore the signal and continue running.

// The run package should be able to stop the program in both cases.

var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.")

func main() {
flag.Parse()

mode := "Well behaved"
if *badlyBehavedFlag {
mode = "Badly behaved"
}
fmt.Printf("%s process %d started.\n", mode, os.Getpid())

// Start a web server on a known port so that we can check that this process is
// not running, when it's been started as a child process, and we don't know
// its pid.
go func() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%d", os.Getpid())
})
err := http.ListenAndServe("127.0.0.1:7777", nil)
if err != nil {
fmt.Printf("Error running web server: %v\n", err)
}
}()

sigs := make(chan os.Signal, 1)
if !*badlyBehavedFlag {
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
}
for {
select {
case <-sigs:
fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid())
return
case <-time.After(1 * time.Second):
fmt.Printf("Process %d still running...\n", os.Getpid())
}
}
}

0 comments on commit 3ac3c9d

Please sign in to comment.