summaryrefslogtreecommitdiff
path: root/vendor/github.com/golang/mock/gomock/controller.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/golang/mock/gomock/controller.go')
-rw-r--r--vendor/github.com/golang/mock/gomock/controller.go144
1 files changed, 89 insertions, 55 deletions
diff --git a/vendor/github.com/golang/mock/gomock/controller.go b/vendor/github.com/golang/mock/gomock/controller.go
index 6bff78d..a7b7918 100644
--- a/vendor/github.com/golang/mock/gomock/controller.go
+++ b/vendor/github.com/golang/mock/gomock/controller.go
@@ -57,7 +57,9 @@ package gomock
import (
"fmt"
+ "golang.org/x/net/context"
"reflect"
+ "runtime"
"sync"
)
@@ -74,17 +76,40 @@ type TestReporter interface {
type Controller struct {
mu sync.Mutex
t TestReporter
- expectedCalls callSet
+ expectedCalls *callSet
+ finished bool
}
func NewController(t TestReporter) *Controller {
return &Controller{
t: t,
- expectedCalls: make(callSet),
+ expectedCalls: newCallSet(),
}
}
+type cancelReporter struct {
+ t TestReporter
+ cancel func()
+}
+
+func (r *cancelReporter) Errorf(format string, args ...interface{}) { r.t.Errorf(format, args...) }
+func (r *cancelReporter) Fatalf(format string, args ...interface{}) {
+ defer r.cancel()
+ r.t.Fatalf(format, args...)
+}
+
+// WithContext returns a new Controller and a Context, which is cancelled on any
+// fatal failure.
+func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) {
+ ctx, cancel := context.WithCancel(ctx)
+ return NewController(&cancelReporter{t, cancel}), ctx
+}
+
func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call {
+ if h, ok := ctrl.t.(testHelper); ok {
+ h.Helper()
+ }
+
recv := reflect.ValueOf(receiver)
for i := 0; i < recv.Type().NumMethod(); i++ {
if recv.Type().Method(i).Name == method {
@@ -92,73 +117,77 @@ func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...
}
}
ctrl.t.Fatalf("gomock: failed finding method %s on %T", method, receiver)
- // In case t.Fatalf does not panic.
- panic(fmt.Sprintf("gomock: failed finding method %s on %T", method, receiver))
+ panic("unreachable")
}
func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
- // TODO: check arity, types.
- margs := make([]Matcher, len(args))
- for i, arg := range args {
- if m, ok := arg.(Matcher); ok {
- margs[i] = m
- } else if arg == nil {
- // Handle nil specially so that passing a nil interface value
- // will match the typed nils of concrete args.
- margs[i] = Nil()
- } else {
- margs[i] = Eq(arg)
- }
+ if h, ok := ctrl.t.(testHelper); ok {
+ h.Helper()
}
+ call := newCall(ctrl.t, receiver, method, methodType, args...)
+
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
-
- call := &Call{t: ctrl.t, receiver: receiver, method: method, methodType: methodType, args: margs, minCalls: 1, maxCalls: 1}
-
ctrl.expectedCalls.Add(call)
+
return call
}
func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
- ctrl.mu.Lock()
- defer ctrl.mu.Unlock()
-
- expected := ctrl.expectedCalls.FindMatch(receiver, method, args)
- if expected == nil {
- ctrl.t.Fatalf("no matching expected call: %T.%v(%v)", receiver, method, args)
+ if h, ok := ctrl.t.(testHelper); ok {
+ h.Helper()
}
- // Two things happen here:
- // * the matching call no longer needs to check prerequite calls,
- // * and the prerequite calls are no longer expected, so remove them.
- preReqCalls := expected.dropPrereqs()
- for _, preReqCall := range preReqCalls {
- ctrl.expectedCalls.Remove(preReqCall)
- }
+ // Nest this code so we can use defer to make sure the lock is released.
+ actions := func() []func([]interface{}) []interface{} {
+ ctrl.mu.Lock()
+ defer ctrl.mu.Unlock()
- rets, action := expected.call(args)
- if expected.exhausted() {
- ctrl.expectedCalls.Remove(expected)
- }
+ expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
+ if err != nil {
+ origin := callerInfo(2)
+ ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
+ }
- // Don't hold the lock while doing the call's action (if any)
- // so that actions may execute concurrently.
- // We use the deferred Unlock to capture any panics that happen above;
- // here we add a deferred Lock to balance it.
- ctrl.mu.Unlock()
- defer ctrl.mu.Lock()
- if action != nil {
- action()
+ // Two things happen here:
+ // * the matching call no longer needs to check prerequite calls,
+ // * and the prerequite calls are no longer expected, so remove them.
+ preReqCalls := expected.dropPrereqs()
+ for _, preReqCall := range preReqCalls {
+ ctrl.expectedCalls.Remove(preReqCall)
+ }
+
+ actions := expected.call(args)
+ if expected.exhausted() {
+ ctrl.expectedCalls.Remove(expected)
+ }
+ return actions
+ }()
+
+ var rets []interface{}
+ for _, action := range actions {
+ if r := action(args); r != nil {
+ rets = r
+ }
}
return rets
}
func (ctrl *Controller) Finish() {
+ if h, ok := ctrl.t.(testHelper); ok {
+ h.Helper()
+ }
+
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
+ if ctrl.finished {
+ ctrl.t.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.")
+ }
+ ctrl.finished = true
+
// If we're currently panicking, probably because this is a deferred call,
// pass through the panic.
if err := recover(); err != nil {
@@ -166,18 +195,23 @@ func (ctrl *Controller) Finish() {
}
// Check that all remaining expected calls are satisfied.
- failures := false
- for _, methodMap := range ctrl.expectedCalls {
- for _, calls := range methodMap {
- for _, call := range calls {
- if !call.satisfied() {
- ctrl.t.Errorf("missing call(s) to %v", call)
- failures = true
- }
- }
- }
+ failures := ctrl.expectedCalls.Failures()
+ for _, call := range failures {
+ ctrl.t.Errorf("missing call(s) to %v", call)
}
- if failures {
+ if len(failures) != 0 {
ctrl.t.Fatalf("aborting test due to missing call(s)")
}
}
+
+func callerInfo(skip int) string {
+ if _, file, line, ok := runtime.Caller(skip + 1); ok {
+ return fmt.Sprintf("%s:%d", file, line)
+ }
+ return "unknown file"
+}
+
+type testHelper interface {
+ TestReporter
+ Helper()
+}