Commit f3aac71f authored by zsfelfoldi's avatar zsfelfoldi

rpc/v2: optionally passing context argument to rpc v2 api methods

parent fa187a36
...@@ -101,6 +101,10 @@ ...@@ -101,6 +101,10 @@
"ImportPath": "golang.org/x/crypto/scrypt", "ImportPath": "golang.org/x/crypto/scrypt",
"Rev": "4ed45ec682102c643324fae5dff8dab085b6c300" "Rev": "4ed45ec682102c643324fae5dff8dab085b6c300"
}, },
{
"ImportPath": "golang.org/x/net/context",
"Rev": "e0403b4e005737430c05a57aac078479844f919c"
},
{ {
"ImportPath": "golang.org/x/net/html", "ImportPath": "golang.org/x/net/html",
"Rev": "e0403b4e005737430c05a57aac078479844f919c" "Rev": "e0403b4e005737430c05a57aac078479844f919c"
......
This diff is collapsed.
This diff is collapsed.
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context_test
import (
"fmt"
"time"
"golang.org/x/net/context"
)
func ExampleWithTimeout() {
// Pass a context with a timeout to tell a blocking function that it
// should abandon its work after the timeout elapses.
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
select {
case <-time.After(200 * time.Millisecond):
fmt.Println("overslept")
case <-ctx.Done():
fmt.Println(ctx.Err()) // prints "context deadline exceeded"
}
// Output:
// context deadline exceeded
}
...@@ -25,6 +25,7 @@ import ( ...@@ -25,6 +25,7 @@ import (
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/logger/glog"
"golang.org/x/net/context"
) )
// NewServer will create a new server instance with no registered handlers. // NewServer will create a new server instance with no registered handlers.
...@@ -120,6 +121,9 @@ func (s *Server) ServeCodec(codec ServerCodec) { ...@@ -120,6 +121,9 @@ func (s *Server) ServeCodec(codec ServerCodec) {
codec.Close() codec.Close()
}() }()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for { for {
reqs, batch, err := s.readRequest(codec) reqs, batch, err := s.readRequest(codec)
if err != nil { if err != nil {
...@@ -129,9 +133,9 @@ func (s *Server) ServeCodec(codec ServerCodec) { ...@@ -129,9 +133,9 @@ func (s *Server) ServeCodec(codec ServerCodec) {
} }
if batch { if batch {
go s.execBatch(codec, reqs) go s.execBatch(ctx, codec, reqs)
} else { } else {
go s.exec(codec, reqs[0]) go s.exec(ctx, codec, reqs[0])
} }
} }
} }
...@@ -220,7 +224,7 @@ func (s *Server) unsubscribe(subid string) bool { ...@@ -220,7 +224,7 @@ func (s *Server) unsubscribe(subid string) bool {
} }
// handle executes a request and returns the response from the callback. // handle executes a request and returns the response from the callback.
func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} { func (s *Server) handle(ctx context.Context, codec ServerCodec, req *serverRequest) interface{} {
if req.err != nil { if req.err != nil {
return codec.CreateErrorResponse(&req.id, req.err) return codec.CreateErrorResponse(&req.id, req.err)
} }
...@@ -255,6 +259,9 @@ func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} { ...@@ -255,6 +259,9 @@ func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} {
} }
arguments := []reflect.Value{req.callb.rcvr} arguments := []reflect.Value{req.callb.rcvr}
if req.callb.hasCtx {
arguments = append(arguments, reflect.ValueOf(ctx))
}
if len(req.args) > 0 { if len(req.args) > 0 {
arguments = append(arguments, req.args...) arguments = append(arguments, req.args...)
} }
...@@ -277,12 +284,12 @@ func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} { ...@@ -277,12 +284,12 @@ func (s *Server) handle(codec ServerCodec, req *serverRequest) interface{} {
} }
// exec executes the given request and writes the result back using the codec. // exec executes the given request and writes the result back using the codec.
func (s *Server) exec(codec ServerCodec, req *serverRequest) { func (s *Server) exec(ctx context.Context, codec ServerCodec, req *serverRequest) {
var response interface{} var response interface{}
if req.err != nil { if req.err != nil {
response = codec.CreateErrorResponse(&req.id, req.err) response = codec.CreateErrorResponse(&req.id, req.err)
} else { } else {
response = s.handle(codec, req) response = s.handle(ctx, codec, req)
} }
if err := codec.Write(response); err != nil { if err := codec.Write(response); err != nil {
...@@ -293,13 +300,13 @@ func (s *Server) exec(codec ServerCodec, req *serverRequest) { ...@@ -293,13 +300,13 @@ func (s *Server) exec(codec ServerCodec, req *serverRequest) {
// execBatch executes the given requests and writes the result back using the codec. It will only write the response // execBatch executes the given requests and writes the result back using the codec. It will only write the response
// back when the last request is processed. // back when the last request is processed.
func (s *Server) execBatch(codec ServerCodec, requests []*serverRequest) { func (s *Server) execBatch(ctx context.Context, codec ServerCodec, requests []*serverRequest) {
responses := make([]interface{}, len(requests)) responses := make([]interface{}, len(requests))
for i, req := range requests { for i, req := range requests {
if req.err != nil { if req.err != nil {
responses[i] = codec.CreateErrorResponse(&req.id, req.err) responses[i] = codec.CreateErrorResponse(&req.id, req.err)
} else { } else {
responses[i] = s.handle(codec, req) responses[i] = s.handle(ctx, codec, req)
} }
} }
......
...@@ -6,6 +6,8 @@ import ( ...@@ -6,6 +6,8 @@ import (
"reflect" "reflect"
"testing" "testing"
"time" "time"
"golang.org/x/net/context"
) )
type Service struct{} type Service struct{}
...@@ -27,6 +29,10 @@ func (s *Service) Echo(str string, i int, args *Args) Result { ...@@ -27,6 +29,10 @@ func (s *Service) Echo(str string, i int, args *Args) Result {
return Result{str, i, args} return Result{str, i, args}
} }
func (s *Service) EchoWithCtx(ctx context.Context, str string, i int, args *Args) Result {
return Result{str, i, args}
}
func (s *Service) Rets() (string, error) { func (s *Service) Rets() (string, error) {
return "", nil return "", nil
} }
...@@ -64,8 +70,8 @@ func TestServerRegisterName(t *testing.T) { ...@@ -64,8 +70,8 @@ func TestServerRegisterName(t *testing.T) {
t.Fatalf("Expected service calc to be registered") t.Fatalf("Expected service calc to be registered")
} }
if len(svc.callbacks) != 3 { if len(svc.callbacks) != 4 {
t.Errorf("Expected 3 callbacks for service 'calc', got %d", len(svc.callbacks)) t.Errorf("Expected 4 callbacks for service 'calc', got %d", len(svc.callbacks))
} }
if len(svc.subscriptions) != 1 { if len(svc.subscriptions) != 1 {
...@@ -217,3 +223,33 @@ func TestServerMethodExecution(t *testing.T) { ...@@ -217,3 +223,33 @@ func TestServerMethodExecution(t *testing.T) {
t.Fatalf("expected %s, got %s\n", expected, codec.output) t.Fatalf("expected %s, got %s\n", expected, codec.output)
} }
} }
func TestServerMethodWithCtx(t *testing.T) {
server := NewServer()
service := new(Service)
if err := server.RegisterName("test", service); err != nil {
t.Fatalf("%v", err)
}
id := int64(12345)
req := jsonRequest{
Method: "echoWithCtx",
Version: "2.0",
Id: &id,
}
args := []interface{}{"string arg", 1122, &Args{"qwerty"}}
req.Payload, _ = json.Marshal(&args)
input, _ := json.Marshal(&req)
codec := &ServerTestCodec{input: input, closer: make(chan interface{})}
go server.ServeCodec(codec)
<-codec.closer
expected := `{"jsonrpc":"2.0","id":12345,"result":{"String":"string arg","Int":1122,"Args":{"S":"qwerty"}}}`
if expected != codec.output {
t.Fatalf("expected %s, got %s\n", expected, codec.output)
}
}
...@@ -22,7 +22,6 @@ import ( ...@@ -22,7 +22,6 @@ import (
"math/big" "math/big"
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
...@@ -41,6 +40,7 @@ type callback struct { ...@@ -41,6 +40,7 @@ type callback struct {
rcvr reflect.Value // receiver of method rcvr reflect.Value // receiver of method
method reflect.Method // callback method reflect.Method // callback
argTypes []reflect.Type // input argument types argTypes []reflect.Type // input argument types
hasCtx bool // method's first argument is a context (not included in argTypes)
errPos int // err return idx, of -1 when method cannot return error errPos int // err return idx, of -1 when method cannot return error
isSubscribe bool // indication if the callback is a subscription isSubscribe bool // indication if the callback is a subscription
} }
......
...@@ -24,6 +24,8 @@ import ( ...@@ -24,6 +24,8 @@ import (
"reflect" "reflect"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
"golang.org/x/net/context"
) )
// Is this an exported - upper case - name? // Is this an exported - upper case - name?
...@@ -107,6 +109,8 @@ func isBlockNumber(t reflect.Type) bool { ...@@ -107,6 +109,8 @@ func isBlockNumber(t reflect.Type) bool {
return t == blockNumberType return t == blockNumberType
} }
var contextType = reflect.TypeOf(new(context.Context)).Elem()
// suitableCallbacks iterates over the methods of the given type. It will determine if a method satisfies the criteria // suitableCallbacks iterates over the methods of the given type. It will determine if a method satisfies the criteria
// for a RPC callback or a subscription callback and adds it to the collection of callbacks or subscriptions. See server // for a RPC callback or a subscription callback and adds it to the collection of callbacks or subscriptions. See server
// documentation for a summary of these criteria. // documentation for a summary of these criteria.
...@@ -129,12 +133,19 @@ METHODS: ...@@ -129,12 +133,19 @@ METHODS:
h.method = method h.method = method
h.errPos = -1 h.errPos = -1
firstArg := 1
numIn := mtype.NumIn()
if numIn >= 2 && mtype.In(1) == contextType {
h.hasCtx = true
firstArg = 2
}
if h.isSubscribe { if h.isSubscribe {
h.argTypes = make([]reflect.Type, mtype.NumIn()-1) // skip rcvr type h.argTypes = make([]reflect.Type, numIn-firstArg) // skip rcvr type
for i := 1; i < mtype.NumIn(); i++ { for i := firstArg; i < numIn; i++ {
argType := mtype.In(i) argType := mtype.In(i)
if isExportedOrBuiltinType(argType) { if isExportedOrBuiltinType(argType) {
h.argTypes[i-1] = argType h.argTypes[i-firstArg] = argType
} else { } else {
continue METHODS continue METHODS
} }
...@@ -144,17 +155,15 @@ METHODS: ...@@ -144,17 +155,15 @@ METHODS:
continue METHODS continue METHODS
} }
numIn := mtype.NumIn()
// determine method arguments, ignore first arg since it's the receiver type // determine method arguments, ignore first arg since it's the receiver type
// Arguments must be exported or builtin types // Arguments must be exported or builtin types
h.argTypes = make([]reflect.Type, numIn-1) h.argTypes = make([]reflect.Type, numIn-firstArg)
for i := 1; i < numIn; i++ { for i := firstArg; i < numIn; i++ {
argType := mtype.In(i) argType := mtype.In(i)
if !isExportedOrBuiltinType(argType) { if !isExportedOrBuiltinType(argType) {
continue METHODS continue METHODS
} }
h.argTypes[i-1] = argType h.argTypes[i-firstArg] = argType
} }
// check that all returned values are exported or builtin types // check that all returned values are exported or builtin types
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment