Testing concurrent code in Golang


I wrote a package that tracks certain hashtags on Twitter, and part of the package is adding hashtags to and removing them from a list. The plan is that this can be done from different goroutines, so tracking and untracking are done by sending hashtags over channels to a single goroutine that owns the list. It looks like this:

var (
    trackChan   = make(chan string)
    untrackChan = make(chan string)
)

func Track(hs ...string) {
    for _, h := range hs {
        trackChan <- h
    }
}

func Untrack(hs ...string) {
    for _, h := range hs {
        untrackChan <- h
    }
}

var hashtags []string

func init() {
    go func() {
        for {
            select {
            case h := <-trackChan:
                if !sliceContains(h, hashtags) {
                    hashtags = append(hashtags, h)
                }
            case h := <-untrackChan:
                if sliceContains(h, hashtags) {
                    hashtags = sliceRemove(h, hashtags)
                }
            }
        }
    }()
}

sliceContains and sliceRemove are (probably) self-explanatory helper functions: the first reports whether a hashtag is in the slice, the second returns the slice without it.

Since I write tests for my code, I wrote three tests for this part: one for tracking, one for untracking and one for the combined scenario. A first attempt might look like this:

func TestTrack(t *testing.T) {
    h := "name"
    Track(h)
    assert.Contains(t, hashtags, h)
}

func TestUntrack(t *testing.T) {
    name := "name"
    hashtags = []string{name, "surname"}
    Untrack(name)
    assert.NotContains(t, hashtags, name)
}

func TestTrackUntrack(t *testing.T) {
    h := "name"

    Track(h)
    Untrack(h)

    assert.NotContains(t, hashtags, h)
}

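But the tests don't work! The problem is that we are dealing with concurrent code, so we don't know when a command has finished executing (concretely, when the hashtags slice has been updated). Track and Untrack return as soon as the background goroutine has received the hashtag, not when it has finished processing it, so the assertion may run before the slice changes. The unsynchronized access to hashtags is also a data race, which the race detector (go test -race) can report.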
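The core of the problem can be reproduced without the hashtag package. In this sketch (all names are hypothetical), an extra proceed channel holds the receiver's update back so the timing is deterministic: the unbuffered send has already returned, yet the state has not changed. The real package has no such channel, but nothing orders the update before Track returns either, so an assertion can observe either state.

```go
package main

import "fmt"

// sendThenRead shows that returning from an unbuffered send only means the
// value was received; the receiver may not have acted on it yet.
func sendThenRead() (afterSend, afterWait int) {
    var counter int

    ch := make(chan int)            // unbuffered, like trackChan
    proceed := make(chan struct{})  // demo-only: holds the update back
    finished := make(chan struct{}) // closed once the update has happened

    go func() {
        v := <-ch // main's send completes once this receive has happened
        <-proceed // demo-only: the update has not run yet
        counter += v
        close(finished)
    }()

    ch <- 1
    afterSend = counter // still 0: the send returned before the update

    close(proceed)
    <-finished // wait explicitly for the update to finish
    afterWait = counter
    return afterSend, afterWait
}

func main() {
    before, after := sendThenRead()
    fmt.Println(before, after) // 0 1
}
```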

Testing concurrent code in Go turned out to be easier after I watched the talk "Testing Techniques" by Andrew Gerrand.

The main idea is to insert a call to an empty (no-op) function right after the interesting code, where it acts as a hook for testing. In a test, the function is replaced with one that sends a value on a channel created by the test. The test then receives from that channel wherever it needs to be sure the logic has run, which blocks the test until it has.
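A minimal, self-contained sketch of the hook idea, independent of the hashtag package. The names afterStep, total, worker and installHook are hypothetical. installHook swaps any func() variable for one that signals on a buffered channel, and returns a wait function plus a restore function for the end of the test.

```go
package main

import "fmt"

// afterStep is a hypothetical hook: a no-op in production code.
var afterStep = func() {}

// total is the state the worker goroutine updates.
var total int

// worker adds every received value to total and then calls the hook.
func worker(in <-chan int) {
    for v := range in {
        total += v
        afterStep() // hook placed right after the interesting code
    }
}

// installHook replaces *hook with a function that signals on a channel.
// wait blocks until the hook has fired once; restore puts the old hook back.
func installHook(hook *func()) (wait func(), restore func()) {
    old := *hook
    signal := make(chan struct{}, 1)
    *hook = func() { signal <- struct{}{} }
    return func() { <-signal }, func() { *hook = old }
}

func main() {
    wait, restore := installHook(&afterStep)
    defer restore()

    in := make(chan int)
    go worker(in)

    in <- 2
    wait()             // returns only after worker has updated total
    fmt.Println(total) // 2
}
```

Because the channel has a buffer of 1, the worker can signal once without waiting for the test; if the hook fires a second time before wait is called, the worker blocks until it is.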

So the solution to the problem above was to define a package-level variable holding an empty function (here doneFunc) and to call it at the point where the goroutine has finished with the hashtags slice. Then, in the test, we create a channel (here doneChan) and redefine doneFunc so that it sends a value on doneChan. (It doesn't matter which value. A bool is perfectly fine, but an empty struct is even better, since it occupies 0 bytes while a bool occupies 1 byte.) As good practice, at the end of the test we restore doneFunc to the empty function.

Updated code:

var (
    trackChan   = make(chan string)
    untrackChan = make(chan string)
    doneFunc    = func() {} // test hook: a no-op here, swapped out by tests
)

func Track(hs ...string) {
    for _, h := range hs {
        trackChan <- h
    }
}

func Untrack(hs ...string) {
    for _, h := range hs {
        untrackChan <- h
    }
}

var hashtags []string

func init() {
    go func() {
        for {
            select {
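            case h := <-trackChan:
                if !sliceContains(h, hashtags) {
                    hashtags = append(hashtags, h)
                }
                // Call the hook for every command, even when hashtags did not
                // change (e.g. h was already tracked); otherwise a test waiting
                // on doneChan would block forever.
                doneFunc()
            case h := <-untrackChan:
                if sliceContains(h, hashtags) {
                    hashtags = sliceRemove(h, hashtags)
                }
                doneFunc() // same reasoning as above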
            }
        }
    }()
}

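Updated tests:

import (
    "testing"

    // assumed: assert is testify's assert package
    "github.com/stretchr/testify/assert"
)
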
func TestTrack(t *testing.T) {
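    // Swap the no-op hook for one that signals on doneChan.
    doneChan := make(chan struct{}, 1)
    doneFunc = func() { doneChan <- struct{}{} }
    // Restore the no-op hook when the test ends.
    defer func() {
        doneFunc = func() {}
    }()

    h := "name"

    Track(h)
    <-doneChan // blocks until the goroutine has processed h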

    assert.Contains(t, hashtags, h)
}

func TestUntrack(t *testing.T) {
    doneChan := make(chan struct{}, 1)
    doneFunc = func() { doneChan <- struct{}{} }
    defer func() {
        doneFunc = func() {}
    }()

    name := "name"
    hashtags = []string{name, "surname"}

    Untrack(name)
    <-doneChan

    assert.NotContains(t, hashtags, name)
}

func TestTrackUntrack(t *testing.T) {
    doneChan := make(chan struct{}, 1)
    doneFunc = func() {
        doneChan <- struct{}{}
    }
    defer func() {
        doneFunc = func() {}
    }()
    h := "name"

    Track(h)
    <-doneChan
    Untrack(h)
    <-doneChan

    assert.NotContains(t, hashtags, h)
}