Skip to content
34 changes: 32 additions & 2 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,32 @@ import (
// regex for GCCGO functions
var gccgoRE = regexp.MustCompile(`\.pN\d+_`)

// safeFormatArg is implemented by argumentMatcher to provide its own safe formatting.
// It avoids traversing reference types that may be concurrently modified.
type safeFormatArg interface {
SafeFormatArg() string
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No.

We don't need to invent a new interface.

As we format the value with %v, users can already implement fmt.Formatter or fmt.Stringer.


// formatArg returns a safe string representation of v for use in Diff output.
// If v implements safeFormatArg that is used; otherwise pointer/map/slice/chan
// types are formatted with %%p (address only) to avoid data races when other
// goroutines concurrently modify them.
func formatArg(v interface{}) string {
if v == nil {
return "<nil>"
}
if sf, ok := v.(safeFormatArg); ok {
return sf.SafeFormatArg()
}
rv := reflect.ValueOf(v)
kind := rv.Kind()
switch kind {
case reflect.Map, reflect.Ptr:
return fmt.Sprintf("(%[1]T=%[1]p)", v)
}
return fmt.Sprintf("(%[1]T=%[1]v)", v)
}

// TestingT is an interface wrapper around *testing.T
type TestingT interface {
Logf(format string, args ...interface{})
Expand Down Expand Up @@ -907,6 +933,10 @@ func (f argumentMatcher) String() string {
return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
}

func (f argumentMatcher) SafeFormatArg() string {
return f.String()
}

// MatchedBy can be used to match a mock call based on only certain properties
// from a complex struct or some calculation. It takes a function that will be
// evaluated with the called argument and will return true when there's a match
Expand Down Expand Up @@ -977,15 +1007,15 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
actualFmt = "(Missing)"
} else {
actual = objects[i]
actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
actualFmt = formatArg(actual)
}

if len(args) <= i {
expected = "(Missing)"
expectedFmt = "(Missing)"
} else {
expected = args[i]
expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
expectedFmt = formatArg(expected)
}

if matcher, ok := expected.(argumentMatcher); ok {
Expand Down