Commit 893c9e02 authored by Mitchell Hashimoto's avatar Mitchell Hashimoto

packer/plugin: Count number of interrupts atomically

parent 71379bc8
......@@ -19,8 +19,14 @@ import (
"os/signal"
"runtime"
"strconv"
"sync/atomic"
)
// This is a count of the number of interrupts the process has received.
// This is updated with sync/atomic whenever a SIGINT is received and can
// be checked by the plugin safely to take action.
var Interrupts int32 = 0
const MagicCookieKey = "PACKER_PLUGIN_MAGIC_COOKIE"
const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d6991ca9872b2"
......@@ -86,17 +92,18 @@ func serve(server *rpc.Server) (err error) {
return
}
// Registers a signal handler to "swallow" interrupts so that the
// Registers a signal handler to swallow and count interrupts so that the
// plugin isn't killed. The main host Packer process is responsible
// for killing the plugins when interrupted.
func swallowInterrupts() {
func countInterrupts() {
ch := make(chan os.Signal, 1)
signal.Notify(ch, os.Interrupt)
go func() {
for {
<-ch
log.Println("Received interrupt signal. Ignoring.")
newCount := atomic.AddInt32(&Interrupts, 1)
log.Printf("Received interrupt signal (count: %d). Ignoring.", newCount)
}
}()
}
......@@ -108,7 +115,7 @@ func ServeBuilder(builder packer.Builder) {
server := rpc.NewServer()
packrpc.RegisterBuilder(server, builder)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
......@@ -122,7 +129,7 @@ func ServeCommand(command packer.Command) {
server := rpc.NewServer()
packrpc.RegisterCommand(server, command)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
......@@ -136,7 +143,7 @@ func ServeHook(hook packer.Hook) {
server := rpc.NewServer()
packrpc.RegisterHook(server, hook)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
......@@ -150,7 +157,7 @@ func ServePostProcessor(p packer.PostProcessor) {
server := rpc.NewServer()
packrpc.RegisterPostProcessor(server, p)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
......@@ -164,7 +171,7 @@ func ServeProvisioner(p packer.Provisioner) {
server := rpc.NewServer()
packrpc.RegisterProvisioner(server, p)
swallowInterrupts()
countInterrupts()
if err := serve(server); err != nil {
log.Printf("ERROR: %s", err)
os.Exit(1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment