Commit 5c251b69 authored by obscuren's avatar obscuren

Merge branch 'fjl-poc8-net-integration' into develop

parents 75d16403 bde3ff16
......@@ -14,45 +14,80 @@
You should have received a copy of the GNU General Public License
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
*/
// Command bootnode runs a bootstrap node for the Discovery Protocol.
package main
import (
"crypto/elliptic"
"crypto/ecdsa"
"encoding/hex"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
)
var (
natType = flag.String("nat", "", "NAT traversal implementation")
pmpGateway = flag.String("gateway", "", "gateway address for NAT-PMP")
listenAddr = flag.String("addr", ":30301", "listen address")
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat"
)
func main() {
var (
listenAddr = flag.String("addr", ":30301", "listen address")
genKey = flag.String("genkey", "", "generate a node key and quit")
nodeKeyFile = flag.String("nodekey", "", "private key filename")
nodeKeyHex = flag.String("nodekeyhex", "", "private key as hex (for testing)")
natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
nodeKey *ecdsa.PrivateKey
err error
)
flag.Parse()
nat, err := p2p.ParseNAT(*natType, *pmpGateway)
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
if *genKey != "" {
writeKey(*genKey)
os.Exit(0)
}
natm, err := nat.Parse(*natdesc)
if err != nil {
log.Fatal("invalid nat:", err)
log.Fatalf("-nat: %v", err)
}
switch {
case *nodeKeyFile == "" && *nodeKeyHex == "":
log.Fatal("Use -nodekey or -nodekeyhex to specify a private key")
case *nodeKeyFile != "" && *nodeKeyHex != "":
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
case *nodeKeyFile != "":
if nodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
log.Fatalf("-nodekey: %v", err)
}
case *nodeKeyHex != "":
if nodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
log.Fatalf("-nodekeyhex: %v", err)
}
}
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.InfoLevel))
key, _ := crypto.GenerateKey()
marshaled := elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y)
if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm); err != nil {
log.Fatal(err)
}
select {}
}
srv := p2p.Server{
MaxPeers: 100,
Identity: p2p.NewSimpleClientIdentity("Ethereum(G)", "0.1", "Peer Server Two", marshaled),
ListenAddr: *listenAddr,
NAT: nat,
NoDial: true,
func writeKey(target string) {
key, err := crypto.GenerateKey()
if err != nil {
log.Fatal("could not generate key: %v", err)
}
if err := srv.Start(); err != nil {
log.Fatal("could not start server:", err)
b := crypto.FromECDSA(key)
if target == "-" {
fmt.Println(hex.EncodeToString(b))
} else {
if err := ioutil.WriteFile(target, b, 0600); err != nil {
log.Fatal("write error: ", err)
}
}
select {}
}
......@@ -21,6 +21,7 @@
package main
import (
"crypto/ecdsa"
"flag"
"fmt"
"log"
......@@ -28,7 +29,9 @@ import (
"os/user"
"path"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/vm"
)
......@@ -42,14 +45,14 @@ var (
StartWebSockets bool
RpcPort int
WsPort int
NatType string
PMPGateway string
OutboundPort string
ShowGenesis bool
AddPeer string
MaxPeer int
GenAddr bool
SeedNode string
BootNodes string
NodeKey *ecdsa.PrivateKey
NAT nat.Interface
SecretFile string
ExportDir string
NonInteractive bool
......@@ -84,6 +87,7 @@ func defaultDataDir() string {
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
func Init() {
// TODO: move common flag processing to cmd/util
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
flag.PrintDefaults()
......@@ -93,18 +97,12 @@ func Init() {
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
flag.BoolVar(&StartRpc, "rpc", false, "start rpc server")
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
flag.BoolVar(&SHH, "shh", true, "whisper protocol (on)")
flag.BoolVar(&Dial, "dial", true, "dial out connections (on)")
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
......@@ -127,8 +125,38 @@ func Init() {
flag.BoolVar(&StartJsConsole, "js", false, "launches javascript console")
flag.BoolVar(&PrintVersion, "version", false, "prints version number")
// Network stuff
var (
nodeKeyFile = flag.String("nodekey", "", "network private key file")
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
)
flag.BoolVar(&Dial, "dial", true, "dial out connections (default on)")
flag.BoolVar(&SHH, "shh", true, "run whisper protocol (default on)")
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.Parse()
var err error
if NAT, err = nat.Parse(*natstr); err != nil {
log.Fatalf("-nat: %v", err)
}
switch {
case *nodeKeyFile != "" && *nodeKeyHex != "":
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
case *nodeKeyFile != "":
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
log.Fatalf("-nodekey: %v", err)
}
case *nodeKeyHex != "":
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
log.Fatalf("-nodekeyhex: %v", err)
}
}
if VmType >= int(vm.MaxVmTy) {
log.Fatal("Invalid VM type ", VmType)
}
......
......@@ -31,6 +31,7 @@ import (
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/state"
)
......@@ -61,21 +62,19 @@ func main() {
utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
ethereum, err := eth.New(&eth.Config{
Name: ClientIdentifier,
Version: Version,
KeyStore: KeyStore,
DataDir: Datadir,
LogFile: LogFile,
LogLevel: LogLevel,
LogFormat: LogFormat,
Identifier: Identifier,
MaxPeers: MaxPeer,
Port: OutboundPort,
NATType: PMPGateway,
PMPGateway: PMPGateway,
KeyRing: KeyRing,
Shh: SHH,
Dial: Dial,
Name: p2p.MakeName(ClientIdentifier, Version),
KeyStore: KeyStore,
DataDir: Datadir,
LogFile: LogFile,
LogLevel: LogLevel,
MaxPeers: MaxPeer,
Port: OutboundPort,
NAT: NAT,
KeyRing: KeyRing,
Shh: SHH,
Dial: Dial,
BootNodes: BootNodes,
NodeKey: NodeKey,
})
if err != nil {
......@@ -135,7 +134,7 @@ func main() {
utils.StartWebSockets(ethereum, WsPort)
}
utils.StartEthereum(ethereum, SeedNode)
utils.StartEthereum(ethereum)
if StartJsConsole {
InitJsConsole(ethereum)
......
......@@ -79,6 +79,12 @@
contract.received({from: eth.coinbase}).changed(function() {
refresh();
});
var ev = contract.SingleTransact({})
ev.watch(function(log) {
someElement.innerHTML += "tnaheousnthaoeu";
});
eth.watch('chain').changed(function() {
refresh();
});
......
......@@ -32,18 +32,6 @@ Rectangle {
width: 500
}
Label {
text: "Client ID"
}
TextField {
text: gui.getCustomIdentifier()
width: 500
placeholderText: "Anonymous"
onTextChanged: {
gui.setCustomIdentifier(text)
}
}
TextArea {
objectName: "statsPane"
width: parent.width
......
......@@ -64,15 +64,6 @@ func (gui *Gui) Transact(recipient, value, gas, gasPrice, d string) (string, err
return gui.xeth.Transact(recipient, value, gas, gasPrice, data)
}
func (gui *Gui) SetCustomIdentifier(customIdentifier string) {
gui.clientIdentity.SetCustomIdentifier(customIdentifier)
gui.config.Save("id", customIdentifier)
}
func (gui *Gui) GetCustomIdentifier() string {
return gui.clientIdentity.GetCustomIdentifier()
}
// functions that allow Gui to implement interface guilogger.LogSystem
func (gui *Gui) SetLogLevel(level logger.LogLevel) {
gui.logLevel = level
......
......@@ -21,6 +21,7 @@
package main
import (
"crypto/ecdsa"
"flag"
"fmt"
"log"
......@@ -31,7 +32,9 @@ import (
"runtime"
"bitbucket.org/kardianos/osext"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/vm"
)
......@@ -39,19 +42,18 @@ var (
Identifier string
KeyRing string
KeyStore string
PMPGateway string
StartRpc bool
StartWebSockets bool
RpcPort int
WsPort int
UseUPnP bool
NatType string
OutboundPort string
ShowGenesis bool
AddPeer string
MaxPeer int
GenAddr bool
SeedNode string
BootNodes string
NodeKey *ecdsa.PrivateKey
NAT nat.Interface
SecretFile string
ExportDir string
NonInteractive bool
......@@ -99,6 +101,7 @@ func defaultDataDir() string {
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
func Init() {
// TODO: move common flag processing to cmd/utils
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
flag.PrintDefaults()
......@@ -108,30 +111,51 @@ func Init() {
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.BoolVar(&UseUPnP, "upnp", true, "enable UPnP support")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
flag.BoolVar(&StartRpc, "rpc", true, "start rpc server")
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
flag.StringVar(&LogFile, "logfile", "", "log file (defaults to standard output)")
flag.StringVar(&Datadir, "datadir", defaultDataDir(), "specifies the datadir to use")
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
flag.StringVar(&ConfigFile, "conf", defaultConfigFile, "config file")
flag.StringVar(&DebugFile, "debug", "", "debug file (no debugging if not set)")
flag.IntVar(&LogLevel, "loglevel", int(logger.InfoLevel), "loglevel: 0-5: silent,error,warn,info,debug,debug detail)")
flag.StringVar(&AssetPath, "asset_path", defaultAssetPath(), "absolute path to GUI assets directory")
// Network stuff
var (
nodeKeyFile = flag.String("nodekey", "", "network private key file")
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
)
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
flag.Parse()
var err error
if NAT, err = nat.Parse(*natstr); err != nil {
log.Fatalf("-nat: %v", err)
}
switch {
case *nodeKeyFile != "" && *nodeKeyHex != "":
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
case *nodeKeyFile != "":
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
log.Fatalf("-nodekey: %v", err)
}
case *nodeKeyHex != "":
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
log.Fatalf("-nodekeyhex: %v", err)
}
}
if VmType >= int(vm.MaxVmTy) {
log.Fatal("Invalid VM type ", VmType)
}
......
......@@ -41,7 +41,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/miner"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/ui/qt/qwhisper"
"github.com/ethereum/go-ethereum/xeth"
"github.com/obscuren/qml"
......@@ -77,9 +76,8 @@ type Gui struct {
xeth *xeth.XEth
Session string
clientIdentity *p2p.SimpleClientIdentity
config *ethutil.ConfigManager
Session string
config *ethutil.ConfigManager
plugins map[string]plugin
......@@ -87,7 +85,7 @@ type Gui struct {
}
// Create GUI, but doesn't start it
func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIdentity *p2p.SimpleClientIdentity, session string, logLevel int) *Gui {
func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, session string, logLevel int) *Gui {
db, err := ethdb.NewLDBDatabase("tx_database")
if err != nil {
panic(err)
......@@ -95,15 +93,14 @@ func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIden
xeth := xeth.New(ethereum)
gui := &Gui{eth: ethereum,
txDb: db,
xeth: xeth,
logLevel: logger.LogLevel(logLevel),
Session: session,
open: false,
clientIdentity: clientIdentity,
config: config,
plugins: make(map[string]plugin),
serviceEvents: make(chan ServEv, 1),
txDb: db,
xeth: xeth,
logLevel: logger.LogLevel(logLevel),
Session: session,
open: false,
config: config,
plugins: make(map[string]plugin),
serviceEvents: make(chan ServEv, 1),
}
data, _ := ethutil.ReadAllFile(path.Join(ethutil.Config.ExecPath, "plugins.json"))
json.Unmarshal([]byte(data), &gui.plugins)
......
......@@ -52,19 +52,18 @@ func run() error {
config := utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
ethereum, err := eth.New(&eth.Config{
Name: ClientIdentifier,
Version: Version,
KeyStore: KeyStore,
DataDir: Datadir,
LogFile: LogFile,
LogLevel: LogLevel,
Identifier: Identifier,
MaxPeers: MaxPeer,
Port: OutboundPort,
NATType: PMPGateway,
PMPGateway: PMPGateway,
KeyRing: KeyRing,
Dial: true,
Name: p2p.MakeName(ClientIdentifier, Version),
KeyStore: KeyStore,
DataDir: Datadir,
LogFile: LogFile,
LogLevel: LogLevel,
MaxPeers: MaxPeer,
Port: OutboundPort,
NAT: NAT,
BootNodes: BootNodes,
NodeKey: NodeKey,
KeyRing: KeyRing,
Dial: true,
})
if err != nil {
mainlogger.Fatalln(err)
......@@ -79,12 +78,12 @@ func run() error {
utils.StartWebSockets(ethereum, WsPort)
}
gui := NewWindow(ethereum, config, ethereum.ClientIdentity().(*p2p.SimpleClientIdentity), KeyRing, LogLevel)
gui := NewWindow(ethereum, config, KeyRing, LogLevel)
utils.RegisterInterrupt(func(os.Signal) {
gui.Stop()
})
go utils.StartEthereum(ethereum, SeedNode)
go utils.StartEthereum(ethereum)
fmt.Println("ETH stack took", time.Since(tstart))
......
......@@ -136,15 +136,15 @@ func (ui *UiLib) Muted(content string) {
func (ui *UiLib) Connect(button qml.Object) {
if !ui.connected {
ui.eth.Start(SeedNode)
ui.eth.Start()
ui.connected = true
button.Set("enabled", false)
}
}
func (ui *UiLib) ConnectToPeer(addr string) {
if err := ui.eth.SuggestPeer(addr); err != nil {
guilogger.Infoln(err)
func (ui *UiLib) ConnectToPeer(nodeURL string) {
if err := ui.eth.SuggestPeer(nodeURL); err != nil {
guilogger.Infoln("SuggestPeer error: " + err.Error())
}
}
......
......@@ -121,13 +121,11 @@ func exit(err error) {
os.Exit(status)
}
func StartEthereum(ethereum *eth.Ethereum, SeedNode string) {
clilogger.Infof("Starting %s", ethereum.ClientIdentity())
err := ethereum.Start(SeedNode)
if err != nil {
func StartEthereum(ethereum *eth.Ethereum) {
clilogger.Infoln("Starting ", ethereum.Name())
if err := ethereum.Start(); err != nil {
exit(err)
}
RegisterInterrupt(func(sig os.Signal) {
ethereum.Stop()
logger.Flush()
......
......@@ -9,7 +9,6 @@ import (
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/p2p"
)
// Implement our EthTest Manager
......@@ -54,13 +53,6 @@ func (tm *TestManager) TxPool() *TxPool {
func (tm *TestManager) EventMux() *event.TypeMux {
return tm.eventMux
}
func (tm *TestManager) Broadcast(msgType p2p.Msg, data []interface{}) {
fmt.Println("Broadcast not implemented")
}
func (tm *TestManager) ClientIdentity() p2p.ClientIdentity {
return nil
}
func (tm *TestManager) KeyManager() *crypto.KeyManager {
return nil
}
......
......@@ -16,7 +16,6 @@ type EthManager interface {
IsListening() bool
Peers() []*p2p.Peer
KeyManager() *crypto.KeyManager
ClientIdentity() p2p.ClientIdentity
Db() ethutil.Database
EventMux() *event.TypeMux
}
......@@ -8,6 +8,8 @@ import (
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"os"
"encoding/hex"
"encoding/json"
......@@ -27,10 +29,11 @@ func init() {
ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256)
}
func Sha3(data []byte) []byte {
func Sha3(data ...[]byte) []byte {
d := sha3.NewKeccak256()
d.Write(data)
for _, b := range data {
d.Write(b)
}
return d.Sum(nil)
}
......@@ -98,6 +101,32 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte {
return elliptic.Marshal(S256(), pub.X, pub.Y)
}
// HexToECDSA parses a secp256k1 private key.
func HexToECDSA(hexkey string) (*ecdsa.PrivateKey, error) {
b, err := hex.DecodeString(hexkey)
if err != nil {
return nil, errors.New("invalid hex string")
}
if len(b) != 32 {
return nil, errors.New("invalid length, need 256 bits")
}
return ToECDSA(b), nil
}
// LoadECDSA loads a secp256k1 private key from the given file.
func LoadECDSA(file string) (*ecdsa.PrivateKey, error) {
buf := make([]byte, 32)
fd, err := os.Open(file)
if err != nil {
return nil, err
}
defer fd.Close()
if _, err := io.ReadFull(fd, buf); err != nil {
return nil, err
}
return ToECDSA(buf), nil
}
func GenerateKey() (*ecdsa.PrivateKey, error) {
return ecdsa.GenerateKey(S256(), rand.Reader)
}
......
......@@ -18,7 +18,7 @@ import (
func TestSha3(t *testing.T) {
msg := []byte("abc")
exp, _ := hex.DecodeString("4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45")
checkhash(t, "Sha3-256", Sha3, msg, exp)
checkhash(t, "Sha3-256", func(in []byte) []byte { return Sha3(in) }, msg, exp)
}
func TestSha256(t *testing.T) {
......
......@@ -25,11 +25,12 @@ package crypto
import (
"bytes"
"code.google.com/p/go-uuid/uuid"
"crypto/ecdsa"
"crypto/elliptic"
"encoding/json"
"io"
"code.google.com/p/go-uuid/uuid"
)
type Key struct {
......
package eth
import (
"crypto/ecdsa"
"fmt"
"net"
"sync"
"strings"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/crypto"
......@@ -12,35 +12,58 @@ import (
"github.com/ethereum/go-ethereum/event"
ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/nat"
"github.com/ethereum/go-ethereum/pow/ezp"
"github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum/go-ethereum/whisper"
)
type Config struct {
Name string
Version string
Identifier string
KeyStore string
DataDir string
LogFile string
LogLevel int
LogFormat string
KeyRing string
MaxPeers int
Port string
NATType string
PMPGateway string
var logger = ethlogger.NewLogger("SERV")
var jsonlogger = ethlogger.NewJsonLogger()
type Config struct {
Name string
KeyStore string
DataDir string
LogFile string
LogLevel int
KeyRing string
LogFormat string
MaxPeers int
Port string
// This should be a space-separated list of
// discovery node URLs.
BootNodes string
// This key is used to identify the node on the network.
// If nil, an ephemeral key is used.
NodeKey *ecdsa.PrivateKey
NAT nat.Interface
Shh bool
Dial bool
KeyManager *crypto.KeyManager
}
var logger = ethlogger.NewLogger("SERV")
var jsonlogger = ethlogger.NewJsonLogger()
func (cfg *Config) parseBootNodes() []*discover.Node {
var ns []*discover.Node
for _, url := range strings.Split(cfg.BootNodes, " ") {
if url == "" {
continue
}
n, err := discover.ParseNode(url)
if err != nil {
logger.Errorf("Bootstrap URL %s: %v\n", url, err)
continue
}
ns = append(ns, n)
}
return ns
}
type Ethereum struct {
// Channel for shutting down the ethereum
......@@ -68,11 +91,7 @@ type Ethereum struct {
WsServer rpc.RpcServer
keyManager *crypto.KeyManager
clientIdentity p2p.ClientIdentity
logger ethlogger.LogSystem
synclock sync.Mutex
syncGroup sync.WaitGroup
logger ethlogger.LogSystem
Mining bool
}
......@@ -105,21 +124,17 @@ func New(config *Config) (*Ethereum, error) {
// Initialise the keyring
keyManager.Init(config.KeyRing, 0, false)
// Create a new client id for this instance. This will help identifying the node on the network
clientId := p2p.NewSimpleClientIdentity(config.Name, config.Version, config.Identifier, keyManager.PublicKey())
saveProtocolVersion(db)
//ethutil.Config.Db = db
eth := &Ethereum{
shutdownChan: make(chan bool),
quit: make(chan bool),
db: db,
keyManager: keyManager,
clientIdentity: clientId,
blacklist: p2p.NewBlacklist(),
eventMux: &event.TypeMux{},
logger: logger,
shutdownChan: make(chan bool),
quit: make(chan bool),
db: db,
keyManager: keyManager,
blacklist: p2p.NewBlacklist(),
eventMux: &event.TypeMux{},
logger: logger,
}
eth.chainManager = core.NewChainManager(db, eth.EventMux())
......@@ -134,21 +149,22 @@ func New(config *Config) (*Ethereum, error) {
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()}
nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway)
if err != nil {
return nil, err
netprv := config.NodeKey
if netprv == nil {
if netprv, err = crypto.GenerateKey(); err != nil {
return nil, fmt.Errorf("could not generate server key: %v", err)
}
}
eth.net = &p2p.Server{
Identity: clientId,
MaxPeers: config.MaxPeers,
Protocols: protocols,
Blacklist: eth.blacklist,
NAT: nat,
NoDial: !config.Dial,
PrivateKey: netprv,
Name: config.Name,
MaxPeers: config.MaxPeers,
Protocols: protocols,
Blacklist: eth.blacklist,
NAT: config.NAT,
NoDial: !config.Dial,
BootstrapNodes: config.parseBootNodes(),
}
if len(config.Port) > 0 {
eth.net.ListenAddr = ":" + config.Port
}
......@@ -164,8 +180,8 @@ func (s *Ethereum) Logger() ethlogger.LogSystem {
return s.logger
}
func (s *Ethereum) ClientIdentity() p2p.ClientIdentity {
return s.clientIdentity
func (s *Ethereum) Name() string {
return s.net.Name
}
func (s *Ethereum) ChainManager() *core.ChainManager {
......@@ -221,12 +237,12 @@ func (s *Ethereum) Coinbase() []byte {
}
// Start the ethereum
func (s *Ethereum) Start(seedNode string) error {
func (s *Ethereum) Start() error {
jsonlogger.LogJson(&ethlogger.LogStarting{
ClientString: s.ClientIdentity().String(),
ClientString: s.net.Name,
Coinbase: ethutil.Bytes2Hex(s.KeyManager().Address()),
ProtocolVersion: ProtocolVersion,
LogEvent: ethlogger.LogEvent{Guid: ethutil.Bytes2Hex(s.ClientIdentity().Pubkey())},
LogEvent: ethlogger.LogEvent{Guid: ethutil.Bytes2Hex(crypto.FromECDSAPub(&s.net.PrivateKey.PublicKey))},
})
err := s.net.Start()
......@@ -250,26 +266,16 @@ func (s *Ethereum) Start(seedNode string) error {
s.blockSub = s.eventMux.Subscribe(core.NewMinedBlockEvent{})
go s.blockBroadcastLoop()
// TODO: read peers here
if len(seedNode) > 0 {
logger.Infof("Connect to seed node %v", seedNode)
if err := s.SuggestPeer(seedNode); err != nil {
logger.Infoln(err)
}
}
logger.Infoln("Server started")
return nil
}
func (self *Ethereum) SuggestPeer(addr string) error {
netaddr, err := net.ResolveTCPAddr("tcp", addr)
func (self *Ethereum) SuggestPeer(nodeURL string) error {
n, err := discover.ParseNode(nodeURL)
if err != nil {
logger.Errorf("couldn't resolve %s:", addr, err)
return err
return fmt.Errorf("invalid node URL: %v", err)
}
self.net.SuggestPeer(netaddr.IP, netaddr.Port, nil)
self.net.SuggestPeer(n)
return nil
}
......
......@@ -92,13 +92,14 @@ func EthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool)
// the main loop that handles incoming messages
// note RemovePeer in the post-disconnect hook
func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool, peer *p2p.Peer, rw p2p.MsgReadWriter) (err error) {
id := peer.ID()
self := &ethProtocol{
txPool: txPool,
chainManager: chainManager,
blockPool: blockPool,
rw: rw,
peer: peer,
id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]),
id: fmt.Sprintf("%x", id[:8]),
}
err = self.handleStatus()
if err == nil {
......
......@@ -14,6 +14,7 @@ import (
"github.com/ethereum/go-ethereum/ethutil"
ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
)
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
......@@ -128,26 +129,11 @@ func (self *testBlockPool) RemovePeer(peerId string) {
}
}
// TODO: refactor this into p2p/client_identity
type peerId struct {
pubkey []byte
}
func (self *peerId) String() string {
return "test peer"
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func testPeer() *p2p.Peer {
return p2p.NewPeer(&peerId{}, []p2p.Cap{})
var id discover.NodeID
pk := crypto.GenerateNewKeyPair().PublicKey
copy(id[:], pk)
return p2p.NewPeer(id, "test peer", []p2p.Cap{})
}
type ethProtocolTester struct {
......
......@@ -197,12 +197,13 @@ func (self *JSRE) watch(call otto.FunctionCall) otto.Value {
}
func (self *JSRE) addPeer(call otto.FunctionCall) otto.Value {
host, err := call.Argument(0).ToString()
nodeURL, err := call.Argument(0).ToString()
if err != nil {
return otto.FalseValue()
}
self.ethereum.SuggestPeer(host)
if err := self.ethereum.SuggestPeer(nodeURL); err != nil {
return otto.FalseValue()
}
return otto.TrueValue()
}
......
package p2p
import (
"fmt"
"runtime"
)
// ClientIdentity represents the identity of a peer.
type ClientIdentity interface {
String() string // human readable identity
Pubkey() []byte // 512-bit public key
}
type SimpleClientIdentity struct {
clientIdentifier string
version string
customIdentifier string
os string
implementation string
pubkey []byte
}
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
clientIdentity := &SimpleClientIdentity{
clientIdentifier: clientIdentifier,
version: version,
customIdentifier: customIdentifier,
os: runtime.GOOS,
implementation: runtime.Version(),
pubkey: pubkey,
}
return clientIdentity
}
func (c *SimpleClientIdentity) init() {
}
func (c *SimpleClientIdentity) String() string {
var id string
if len(c.customIdentifier) > 0 {
id = "/" + c.customIdentifier
}
return fmt.Sprintf("%s/v%s%s/%s/%s",
c.clientIdentifier,
c.version,
id,
c.os,
c.implementation)
}
func (c *SimpleClientIdentity) Pubkey() []byte {
return []byte(c.pubkey)
}
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
c.customIdentifier = customIdentifier
}
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
return c.customIdentifier
}
package p2p
import (
"fmt"
"runtime"
"testing"
)
func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
customIdentifier := clientIdentity.GetCustomIdentifier()
if customIdentifier != "test" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
}
clientIdentity.SetCustomIdentifier("test2")
customIdentifier = clientIdentity.GetCustomIdentifier()
if customIdentifier != "test2" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
}
clientString = clientIdentity.String()
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
}
This diff is collapsed.
package p2p
import (
"bytes"
"crypto/ecdsa"
"crypto/rand"
"net"
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/obscuren/ecies"
)
func TestPublicKeyEncoding(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
pub0s := crypto.FromECDSAPub(pub0)
pub1, err := importPublicKey(pub0s)
if err != nil {
t.Errorf("%v", err)
}
eciesPub1 := ecies.ImportECDSAPublic(pub1)
if eciesPub1 == nil {
t.Errorf("invalid ecdsa public key")
}
pub1s, err := exportPublicKey(pub1)
if err != nil {
t.Errorf("%v", err)
}
if len(pub1s) != 64 {
t.Errorf("wrong length expect 64, got", len(pub1s))
}
pub2, err := importPublicKey(pub1s)
if err != nil {
t.Errorf("%v", err)
}
pub2s, err := exportPublicKey(pub2)
if err != nil {
t.Errorf("%v", err)
}
if !bytes.Equal(pub1s, pub2s) {
t.Errorf("exports dont match")
}
pub2sEC := crypto.FromECDSAPub(pub2)
if !bytes.Equal(pub0s, pub2sEC) {
t.Errorf("exports dont match")
}
}
func TestSharedSecret(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
if err != nil {
return
}
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
if err != nil {
return
}
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
if !bytes.Equal(ss0, ss1) {
t.Errorf("dont match :(")
}
}
func TestCryptoHandshake(t *testing.T) {
testCryptoHandshake(newkey(), newkey(), nil, t)
}
func TestCryptoHandshakeWithToken(t *testing.T) {
sessionToken := make([]byte, shaLen)
rand.Read(sessionToken)
testCryptoHandshake(newkey(), newkey(), sessionToken, t)
}
func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
var err error
// pub0 := &prv0.PublicKey
pub1 := &prv1.PublicKey
// pub0s := crypto.FromECDSAPub(pub0)
pub1s := crypto.FromECDSAPub(pub1)
// simulate handshake by feeding output to input
// initiator sends handshake 'auth'
auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("-> %v", hexkey(auth))
// receiver reads auth and responds with response
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("<- %v\n", hexkey(response))
// initiator reads receiver's response and the key exchange completes
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
if err != nil {
t.Errorf("completeHandshake error: %v", err)
}
// now both parties should have the same session parameters
initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
if err != nil {
t.Errorf("newSession error: %v", err)
}
recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
if err != nil {
t.Errorf("newSession error: %v", err)
}
// fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
// fmt.Printf("\nauth %x\ninitNonce %x\nresponse%x\nremoteRecNonce %x\nremoteInitNonce %x\nremoteRandomPubKey %x\nrecNonce %x\nremoteInitRandomPubKey %x\ninitSessionToken %x\n\n", auth, initNonce, response, remoteRecNonce, remoteInitNonce, remoteRandomPubKey, recNonce, remoteInitRandomPubKey, initSessionToken)
if !bytes.Equal(initNonce, remoteInitNonce) {
t.Errorf("nonces do not match")
}
if !bytes.Equal(recNonce, remoteRecNonce) {
t.Errorf("receiver nonces do not match")
}
if !bytes.Equal(initSessionToken, recSessionToken) {
t.Errorf("session tokens do not match")
}
}
func TestHandshake(t *testing.T) {
defer testlog(t).detach()
prv0, _ := crypto.GenerateKey()
prv1, _ := crypto.GenerateKey()
pub0s, _ := exportPublicKey(&prv0.PublicKey)
pub1s, _ := exportPublicKey(&prv1.PublicKey)
rw0, rw1 := net.Pipe()
tokens := make(chan []byte)
go func() {
token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
if err != nil {
t.Errorf("outbound side error: %v", err)
}
tokens <- token
}()
go func() {
token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
if err != nil {
t.Errorf("inbound side error: %v", err)
}
if !bytes.Equal(remotePubkey, pub0s) {
t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
}
tokens <- token
}()
t1, t2 := <-tokens, <-tokens
if !bytes.Equal(t1, t2) {
t.Error("session token mismatch")
}
}
package discover
import (
"crypto/ecdsa"
"crypto/elliptic"
"encoding/hex"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/rlp"
)
const nodeIDBits = 512
// Node represents a host on the network.
type Node struct {
ID NodeID
IP net.IP
DiscPort int // UDP listening port for discovery protocol
TCPPort int // TCP listening port for RLPx
active time.Time
}
func newNode(id NodeID, addr *net.UDPAddr) *Node {
return &Node{
ID: id,
IP: addr.IP,
DiscPort: addr.Port,
TCPPort: addr.Port,
active: time.Now(),
}
}
func (n *Node) isValid() bool {
// TODO: don't accept localhost, LAN addresses from internet hosts
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
}
// The string representation of a Node is a URL.
// Please see ParseNode for a description of the format.
func (n *Node) String() string {
addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
u := url.URL{
Scheme: "enode",
User: url.User(fmt.Sprintf("%x", n.ID[:])),
Host: addr.String(),
}
if n.DiscPort != n.TCPPort {
u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
}
return u.String()
}
// ParseNode parses a node URL.
//
// A node URL has scheme "enode".
//
// The hexadecimal node ID is encoded in the username portion of the
// URL, separated from the host by an @ sign. The hostname can only be
// given as an IP address, DNS domain names are not allowed. The port
// in the host name section is the TCP listening port. If the TCP and
// UDP (discovery) ports differ, the UDP port is specified as query
// parameter "discport".
//
// In the following example, the node URL describes
// a node with IP address 10.3.58.6, TCP listening port 30303
// and UDP discovery port 30301.
//
// enode://<hex node id>@10.3.58.6:30303?discport=30301
func ParseNode(rawurl string) (*Node, error) {
var n Node
u, err := url.Parse(rawurl)
if u.Scheme != "enode" {
return nil, errors.New("invalid URL scheme, want \"enode\"")
}
if u.User == nil {
return nil, errors.New("does not contain node ID")
}
if n.ID, err = HexID(u.User.String()); err != nil {
return nil, fmt.Errorf("invalid node ID (%v)", err)
}
ip, port, err := net.SplitHostPort(u.Host)
if err != nil {
return nil, fmt.Errorf("invalid host: %v", err)
}
if n.IP = net.ParseIP(ip); n.IP == nil {
return nil, errors.New("invalid IP address")
}
if n.TCPPort, err = strconv.Atoi(port); err != nil {
return nil, errors.New("invalid port")
}
qv := u.Query()
if qv.Get("discport") == "" {
n.DiscPort = n.TCPPort
} else {
if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
return nil, errors.New("invalid discport in query")
}
}
return &n, nil
}
// MustParseNode parses a node URL. It panics if the URL is not valid.
func MustParseNode(rawurl string) *Node {
n, err := ParseNode(rawurl)
if err != nil {
panic("invalid node URL: " + err.Error())
}
return n
}
func (n Node) EncodeRLP(w io.Writer) error {
return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
}
func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
var ext rpcNode
if err = s.Decode(&ext); err == nil {
n.TCPPort = int(ext.Port)
n.DiscPort = int(ext.Port)
n.ID = ext.ID
if n.IP = net.ParseIP(ext.IP); n.IP == nil {
return errors.New("invalid IP string")
}
}
return err
}
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key.
type NodeID [nodeIDBits / 8]byte
// NodeID prints as a long hexadecimal number.
func (n NodeID) String() string {
return fmt.Sprintf("%#x", n[:])
}
// The Go syntax representation of a NodeID is a call to HexID.
func (n NodeID) GoString() string {
return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
}
// HexID converts a hex string to a NodeID.
// The string may be prefixed with 0x.
func HexID(in string) (NodeID, error) {
if strings.HasPrefix(in, "0x") {
in = in[2:]
}
var id NodeID
b, err := hex.DecodeString(in)
if err != nil {
return id, err
} else if len(b) != len(id) {
return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
}
copy(id[:], b)
return id, nil
}
// MustHexID converts a hex string to a NodeID.
// It panics if the string is not a valid NodeID.
func MustHexID(in string) NodeID {
id, err := HexID(in)
if err != nil {
panic(err)
}
return id
}
// PubkeyID returns a marshaled representation of the given public key.
func PubkeyID(pub *ecdsa.PublicKey) NodeID {
var id NodeID
pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
if len(pbytes)-1 != len(id) {
panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
}
copy(id[:], pbytes[1:])
return id
}
// recoverNodeID computes the public key used to sign the
// given hash from the signature.
func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
pubkey, err := secp256k1.RecoverPubkey(hash, sig)
if err != nil {
return id, err
}
if len(pubkey)-1 != len(id) {
return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
}
for i := range id {
id[i] = pubkey[i+1]
}
return id, nil
}
// distcmp compares the distances a->target and b->target.
// Returns -1 if a is closer to target, 1 if b is closer to target
// and 0 if they are equal.
func distcmp(target, a, b NodeID) int {
for i := range target {
da := a[i] ^ target[i]
db := b[i] ^ target[i]
if da > db {
return 1
} else if da < db {
return -1
}
}
return 0
}
// table of leading zero counts for bytes [0..255]
var lzcount = [256]int{
8, 7, 6, 6, 5, 5, 5, 5,
4, 4, 4, 4, 4, 4, 4, 4,
3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
}
// logdist returns the logarithmic distance between a and b, log2(a ^ b).
func logdist(a, b NodeID) int {
lz := 0
for i := range a {
x := a[i] ^ b[i]
if x == 0 {
lz += 8
} else {
lz += lzcount[x]
break
}
}
return len(a)*8 - lz
}
// randomID returns a random NodeID such that logdist(a, b) == n
func randomID(a NodeID, n int) (b NodeID) {
if n == 0 {
return a
}
// flip bit at position n, fill the rest with random bits
b = a
pos := len(a) - n/8 - 1
bit := byte(0x01) << (byte(n%8) - 1)
if bit == 0 {
pos++
bit = 0x80
}
b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
for i := pos + 1; i < len(a); i++ {
b[i] = byte(rand.Intn(255))
}
return b
}
package discover
import (
"math/big"
"math/rand"
"net"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/ethereum/go-ethereum/crypto"
)
var (
quickrand = rand.New(rand.NewSource(time.Now().Unix()))
quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
)
var parseNodeTests = []struct {
rawurl string
wantError string
wantResult *Node
}{
{
rawurl: "http://foobar",
wantError: `invalid URL scheme, want "enode"`,
},
{
rawurl: "enode://foobar",
wantError: `does not contain node ID`,
},
{
rawurl: "enode://01010101@123.124.125.126:3",
wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
wantError: `invalid IP address`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
wantError: `invalid port`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
wantError: `invalid discport in query`,
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("127.0.0.1"),
DiscPort: 52150,
TCPPort: 52150,
},
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("::"),
DiscPort: 52150,
TCPPort: 52150,
},
},
{
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
wantResult: &Node{
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
IP: net.ParseIP("127.0.0.1"),
DiscPort: 223344,
TCPPort: 52150,
},
},
}
func TestParseNode(t *testing.T) {
for i, test := range parseNodeTests {
n, err := ParseNode(test.rawurl)
if err == nil && test.wantError != "" {
t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
continue
}
if err != nil && err.Error() != test.wantError {
t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
continue
}
if !reflect.DeepEqual(n, test.wantResult) {
t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
}
}
}
func TestNodeString(t *testing.T) {
for i, test := range parseNodeTests {
if test.wantError != "" {
continue
}
str := test.wantResult.String()
if str != test.rawurl {
t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
}
}
}
func TestHexID(t *testing.T) {
ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
if id1 != ref {
t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
}
if id2 != ref {
t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
}
}
func TestNodeID_recover(t *testing.T) {
prv := newkey()
hash := make([]byte, 32)
sig, err := crypto.Sign(hash, prv)
if err != nil {
t.Fatalf("signing error: %v", err)
}
pub := PubkeyID(&prv.PublicKey)
recpub, err := recoverNodeID(hash, sig)
if err != nil {
t.Fatalf("recovery error: %v", err)
}
if pub != recpub {
t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
}
}
func TestNodeID_distcmp(t *testing.T) {
distcmpBig := func(target, a, b NodeID) int {
tbig := new(big.Int).SetBytes(target[:])
abig := new(big.Int).SetBytes(a[:])
bbig := new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
}
if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_distcmpEqual(t *testing.T) {
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
if distcmp(base, x, x) != 0 {
t.Errorf("distcmp(base, x, x) != 0")
}
}
func TestNodeID_logdist(t *testing.T) {
logdistBig := func(a, b NodeID) int {
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
return new(big.Int).Xor(abig, bbig).BitLen()
}
if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
t.Error(err)
}
}
// the random tests is likely to miss the case where they're equal.
func TestNodeID_logdistEqual(t *testing.T) {
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
if logdist(x, x) != 0 {
t.Errorf("logdist(x, x) != 0")
}
}
func TestNodeID_randomID(t *testing.T) {
// we don't use quick.Check here because its output isn't
// very helpful when the test fails.
for i := 0; i < quickcfg.MaxCount; i++ {
a := gen(NodeID{}, quickrand).(NodeID)
dist := quickrand.Intn(len(NodeID{}) * 8)
result := randomID(a, dist)
actualdist := logdist(result, a)
if dist != actualdist {
t.Log("a: ", a)
t.Log("result:", result)
t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
}
}
}
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
var id NodeID
m := rand.Intn(len(id))
for i := len(id) - 1; i > m; i-- {
id[i] = byte(rand.Uint32())
}
return reflect.ValueOf(id)
}
// Package discover implements the Node Discovery Protocol.
//
// The Node Discovery protocol provides a way to find RLPx nodes that
// can be connected to. It uses a Kademlia-like protocol to maintain a
// distributed database of the IDs and endpoints of all listening
// nodes.
package discover
import (
"net"
"sort"
"sync"
"time"
)
const (
alpha = 3 // Kademlia concurrency factor
bucketSize = 16 // Kademlia bucket size
nBuckets = nodeIDBits + 1 // Number of buckets
)
type Table struct {
mutex sync.Mutex // protects buckets, their content, and nursery
buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*Node // bootstrap nodes
net transport
self *Node // metadata of the local node
}
// transport is implemented by the UDP transport.
// it is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
ping(*Node) error
findnode(e *Node, target NodeID) ([]*Node, error)
close()
}
// bucket contains nodes, ordered by their last activity.
type bucket struct {
lastLookup time.Time
entries []*Node
}
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
for i := range tab.buckets {
tab.buckets[i] = new(bucket)
}
return tab
}
// Self returns the local node ID.
func (tab *Table) Self() NodeID {
return tab.self.ID
}
// Close terminates the network listener.
func (tab *Table) Close() {
tab.net.close()
}
// Bootstrap sets the bootstrap nodes. These nodes are used to connect
// to the network if the table is empty. Bootstrap will also attempt to
// fill the table by performing random lookup operations on the
// network.
func (tab *Table) Bootstrap(nodes []*Node) {
tab.mutex.Lock()
// TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
tab.nursery = make([]*Node, 0, len(nodes))
for _, n := range nodes {
cpy := *n
tab.nursery = append(tab.nursery, &cpy)
}
tab.mutex.Unlock()
tab.refresh()
}
// Lookup performs a network search for nodes close
// to the given target. It approaches the target by querying
// nodes that are closer to it on each iteration.
func (tab *Table) Lookup(target NodeID) []*Node {
var (
asked = make(map[NodeID]bool)
seen = make(map[NodeID]bool)
reply = make(chan []*Node, alpha)
pendingQueries = 0
)
// don't query further if we hit the target or ourself.
// unlikely to happen often in practice.
asked[target] = true
asked[tab.self.ID] = true
tab.mutex.Lock()
// update last lookup stamp (for refresh logic)
tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
// generate initial result set
result := tab.closest(target, bucketSize)
tab.mutex.Unlock()
for {
// ask the alpha closest nodes that we haven't asked yet
for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
n := result.entries[i]
if !asked[n.ID] {
asked[n.ID] = true
pendingQueries++
go func() {
result, _ := tab.net.findnode(n, target)
reply <- result
}()
}
}
if pendingQueries == 0 {
// we have asked all closest nodes, stop the search
break
}
// wait for the next reply
for _, n := range <-reply {
cn := n
if !seen[n.ID] {
seen[n.ID] = true
result.push(cn, bucketSize)
}
}
pendingQueries--
}
return result.entries
}
// refresh performs a lookup for a random target to keep buckets full.
func (tab *Table) refresh() {
ld := -1 // logdist of chosen bucket
tab.mutex.Lock()
for i, b := range tab.buckets {
if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
ld = i
break
}
}
tab.mutex.Unlock()
result := tab.Lookup(randomID(tab.self.ID, ld))
if len(result) == 0 {
// bootstrap the table with a self lookup
tab.mutex.Lock()
tab.add(tab.nursery)
tab.mutex.Unlock()
tab.Lookup(tab.self.ID)
// TODO: the Kademlia paper says that we're supposed to perform
// random lookups in all buckets further away than our closest neighbor.
}
}
// closest returns the n nodes in the table that are closest to the
// given id. The caller must hold tab.mutex.
func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
// This is a very wasteful way to find the closest nodes but
// obviously correct. I believe that tree-based buckets would make
// this easier to implement efficiently.
close := &nodesByDistance{target: target}
for _, b := range tab.buckets {
for _, n := range b.entries {
close.push(n, nresults)
}
}
return close
}
func (tab *Table) len() (n int) {
for _, b := range tab.buckets {
n += len(b.entries)
}
return n
}
// bumpOrAdd updates the activity timestamp for the given node and
// attempts to insert the node into a bucket. The returned Node might
// not be part of the table. The caller must hold tab.mutex.
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
b := tab.buckets[logdist(tab.self.ID, node)]
if n = b.bump(node); n == nil {
n = newNode(node, from)
if len(b.entries) == bucketSize {
tab.pingReplace(n, b)
} else {
b.entries = append(b.entries, n)
}
}
return n
}
func (tab *Table) pingReplace(n *Node, b *bucket) {
old := b.entries[bucketSize-1]
go func() {
if err := tab.net.ping(old); err == nil {
// it responded, we don't need to replace it.
return
}
// it didn't respond, replace the node if it is still the oldest node.
tab.mutex.Lock()
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
// slide down other entries and put the new one in front.
// TODO: insert in correct position to keep the order
copy(b.entries[1:], b.entries)
b.entries[0] = n
}
tab.mutex.Unlock()
}()
}
// bump updates the activity timestamp for the given node.
// The caller must hold tab.mutex.
func (tab *Table) bump(node NodeID) {
tab.buckets[logdist(tab.self.ID, node)].bump(node)
}
// add puts the entries into the table if their corresponding
// bucket is not full. The caller must hold tab.mutex.
func (tab *Table) add(entries []*Node) {
outer:
for _, n := range entries {
if n == nil || n.ID == tab.self.ID {
// skip bad entries. The RLP decoder returns nil for empty
// input lists.
continue
}
bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
for i := range bucket.entries {
if bucket.entries[i].ID == n.ID {
// already in bucket
continue outer
}
}
if len(bucket.entries) < bucketSize {
bucket.entries = append(bucket.entries, n)
}
}
}
func (b *bucket) bump(id NodeID) *Node {
for i, n := range b.entries {
if n.ID == id {
n.active = time.Now()
// move it to the front
copy(b.entries[1:], b.entries[:i+1])
b.entries[0] = n
return n
}
}
return nil
}
// nodesByDistance is a list of nodes, ordered by
// distance to target.
type nodesByDistance struct {
entries []*Node
target NodeID
}
// push adds the given node to the list, keeping the total size below maxElems.
func (h *nodesByDistance) push(n *Node, maxElems int) {
ix := sort.Search(len(h.entries), func(i int) bool {
return distcmp(h.target, h.entries[i].ID, n.ID) > 0
})
if len(h.entries) < maxElems {
h.entries = append(h.entries, n)
}
if ix == len(h.entries) {
// farther away than all nodes we already have.
// if there was room for it, the node is now the last element.
} else {
// slide existing entries down to make room
// this will overwrite the entry we just appended.
copy(h.entries[ix+1:], h.entries[ix:])
h.entries[ix] = n
}
}
package discover
import (
"crypto/ecdsa"
"errors"
"fmt"
"math/rand"
"net"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/ethereum/go-ethereum/crypto"
)
func TestTable_bumpOrAddBucketAssign(t *testing.T) {
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
for i := 1; i < len(tab.buckets); i++ {
tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
}
for i, b := range tab.buckets {
if i > 0 && len(b.entries) != 1 {
t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
}
}
}
func TestTable_bumpOrAddPingReplace(t *testing.T) {
pingC := make(pingC)
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
last := fillBucket(tab, 200)
// this bumpOrAdd should not replace the last node
// because the node replies to ping.
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
pinged := <-pingC
if pinged != last.ID {
t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
if l := len(tab.buckets[200].entries); l != bucketSize {
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
}
if !contains(tab.buckets[200].entries, last.ID) {
t.Error("last entry was removed")
}
if contains(tab.buckets[200].entries, new.ID) {
t.Error("new entry was added")
}
}
func TestTable_bumpOrAddPingTimeout(t *testing.T) {
tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
last := fillBucket(tab, 200)
// this bumpOrAdd should replace the last node
// because the node does not reply to ping.
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
// wait for async bucket update. damn. this needs to go away.
time.Sleep(2 * time.Millisecond)
tab.mutex.Lock()
defer tab.mutex.Unlock()
if l := len(tab.buckets[200].entries); l != bucketSize {
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
}
if contains(tab.buckets[200].entries, last.ID) {
t.Error("last entry was not removed")
}
if !contains(tab.buckets[200].entries, new.ID) {
t.Error("new entry was not added")
}
}
func fillBucket(tab *Table, ld int) (last *Node) {
b := tab.buckets[ld]
for len(b.entries) < bucketSize {
b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
}
return b.entries[bucketSize-1]
}
type pingC chan NodeID
func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
panic("findnode called on pingRecorder")
}
func (t pingC) close() {
panic("close called on pingRecorder")
}
func (t pingC) ping(n *Node) error {
if t == nil {
return errTimeout
}
t <- n.ID
return nil
}
func TestTable_bump(t *testing.T) {
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
// add an old entry and two recent ones
oldactive := time.Now().Add(-2 * time.Minute)
old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
others := []*Node{
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
}
tab.add(append(others, old))
if tab.buckets[200].entries[0] == old {
t.Fatal("old entry is at front of bucket")
}
// bumping the old entry should move it to the front
tab.bump(old.ID)
if old.active == oldactive {
t.Error("activity timestamp not updated")
}
if tab.buckets[200].entries[0] != old {
t.Errorf("bumped entry did not move to the front of bucket")
}
}
func TestTable_closest(t *testing.T) {
t.Parallel()
test := func(test *closeTest) bool {
// for any node table, Target and N
tab := newTable(nil, test.Self, &net.UDPAddr{})
tab.add(test.All)
// check that doClosest(Target, N) returns nodes
result := tab.closest(test.Target, test.N).entries
if hasDuplicates(result) {
t.Errorf("result contains duplicates")
return false
}
if !sortedByDistanceTo(test.Target, result) {
t.Errorf("result is not sorted by distance to target")
return false
}
// check that the number of results is min(N, tablen)
wantN := test.N
if tlen := tab.len(); tlen < test.N {
wantN = tlen
}
if len(result) != wantN {
t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
return false
} else if len(result) == 0 {
return true // no need to check distance
}
// check that the result nodes have minimum distance to target.
for _, b := range tab.buckets {
for _, n := range b.entries {
if contains(result, n.ID) {
continue // don't run the check below for nodes in result
}
farthestResult := result[len(result)-1].ID
if distcmp(test.Target, n.ID, farthestResult) < 0 {
t.Errorf("table contains node that is closer to target but it's not in result")
t.Logf(" Target: %v", test.Target)
t.Logf(" Farthest Result: %v", farthestResult)
t.Logf(" ID: %v", n.ID)
return false
}
}
}
return true
}
if err := quick.Check(test, quickcfg); err != nil {
t.Error(err)
}
}
type closeTest struct {
Self NodeID
Target NodeID
All []*Node
N int
}
func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
t := &closeTest{
Self: gen(NodeID{}, rand).(NodeID),
Target: gen(NodeID{}, rand).(NodeID),
N: rand.Intn(bucketSize),
}
for _, id := range gen([]NodeID{}, rand).([]NodeID) {
t.All = append(t.All, &Node{ID: id})
}
return reflect.ValueOf(t)
}
func TestTable_Lookup(t *testing.T) {
self := gen(NodeID{}, quickrand).(NodeID)
target := randomID(self, 200)
transport := findnodeOracle{t, target}
tab := newTable(transport, self, &net.UDPAddr{})
// lookup on empty table returns no nodes
if results := tab.Lookup(target); len(results) > 0 {
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
}
// seed table with initial node (otherwise lookup will terminate immediately)
tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
results := tab.Lookup(target)
t.Logf("results:")
for _, e := range results {
t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID)
}
if len(results) != bucketSize {
t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
}
if hasDuplicates(results) {
t.Errorf("result set contains duplicate entries")
}
if !sortedByDistanceTo(target, results) {
t.Errorf("result set not sorted by distance to target")
}
if !contains(results, target) {
t.Errorf("result set does not contain target")
}
}
// findnode on this transport always returns at least one node
// that is one bucket closer to the target.
type findnodeOracle struct {
t *testing.T
target NodeID
}
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
t.t.Logf("findnode query at dist %d", n.DiscPort)
// current log distance is encoded in port number
var result []*Node
switch n.DiscPort {
case 0:
panic("query to node at distance 0")
default:
// TODO: add more randomness to distances
next := n.DiscPort - 1
for i := 0; i < bucketSize; i++ {
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
}
}
return result, nil
}
func (t findnodeOracle) close() {}
func (t findnodeOracle) ping(n *Node) error {
return errors.New("ping is not supported by this transport")
}
func hasDuplicates(slice []*Node) bool {
seen := make(map[NodeID]bool)
for _, e := range slice {
if seen[e.ID] {
return true
}
seen[e.ID] = true
}
return false
}
func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
var last NodeID
for i, e := range slice {
if i > 0 && distcmp(distbase, e.ID, last) < 0 {
return false
}
last = e.ID
}
return true
}
func contains(ns []*Node, id NodeID) bool {
for _, n := range ns {
if n.ID == id {
return true
}
}
return false
}
// gen wraps quick.Value so it's easier to use.
// it generates a random value of the given value's type.
func gen(typ interface{}, rand *rand.Rand) interface{} {
v, ok := quick.Value(reflect.TypeOf(typ), rand)
if !ok {
panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
}
return v.Interface()
}
func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey()
if err != nil {
panic("couldn't generate key: " + err.Error())
}
return key
}
This diff is collapsed.
package discover
import (
"fmt"
logpkg "log"
"net"
"os"
"testing"
"time"
"github.com/ethereum/go-ethereum/logger"
)
func init() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
}
func TestUDP_ping(t *testing.T) {
t.Parallel()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
defer n1.Close()
defer n2.Close()
if err := n1.net.ping(n2.self); err != nil {
t.Fatalf("ping error: %v", err)
}
if find(n2, n1.self.ID) == nil {
t.Errorf("node 2 does not contain id of node 1")
}
if e := find(n1, n2.self.ID); e != nil {
t.Errorf("node 1 does contains id of node 2: %v", e)
}
}
func find(tab *Table, id NodeID) *Node {
for _, b := range tab.buckets {
for _, e := range b.entries {
if e.ID == id {
return e
}
}
}
return nil
}
func TestUDP_findnode(t *testing.T) {
t.Parallel()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
defer n1.Close()
defer n2.Close()
// put a few nodes into n2. the exact distribution shouldn't
// matter much, altough we need to take care not to overflow
// any bucket.
target := randomID(n1.self.ID, 100)
nodes := &nodesByDistance{target: target}
for i := 0; i < bucketSize; i++ {
n2.add([]*Node{&Node{
IP: net.IP{1, 2, 3, byte(i)},
DiscPort: i + 2,
TCPPort: i + 2,
ID: randomID(n2.self.ID, i+2),
}})
}
n2.add(nodes.entries)
n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
expected := n2.closest(target, bucketSize)
err := runUDP(10, func() error {
result, _ := n1.net.findnode(n2.self, target)
if len(result) != bucketSize {
return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
}
for i := range result {
if result[i].ID != expected.entries[i].ID {
return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
}
}
return nil
})
if err != nil {
t.Error(err)
}
}
func TestUDP_replytimeout(t *testing.T) {
t.Parallel()
// reserve a port so we don't talk to an existing service by accident
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
fd, err := net.ListenUDP("udp", addr)
if err != nil {
t.Fatal(err)
}
defer fd.Close()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
defer n1.Close()
n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
if err := n1.net.ping(n2); err != errTimeout {
t.Error("expected timeout error, got", err)
}
if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
t.Error("expected timeout error, got", err)
} else if len(result) > 0 {
t.Error("expected empty result, got", result)
}
}
func TestUDP_findnodeMultiReply(t *testing.T) {
t.Parallel()
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
udp2 := n2.net.(*udp)
defer n1.Close()
defer n2.Close()
err := runUDP(10, func() error {
nodes := make([]*Node, bucketSize)
for i := range nodes {
nodes[i] = &Node{
IP: net.IP{1, 2, 3, 4},
DiscPort: i + 1,
TCPPort: i + 1,
ID: randomID(n2.self.ID, i+1),
}
}
// ask N2 for neighbors. it will send an empty reply back.
// the request will wait for up to bucketSize replies.
resultc := make(chan []*Node)
errc := make(chan error)
go func() {
ns, err := n1.net.findnode(n2.self, n1.self.ID)
if err != nil {
errc <- err
} else {
resultc <- ns
}
}()
// send a few more neighbors packets to N1.
// it should collect those.
for end := 0; end < len(nodes); {
off := end
if end = end + 5; end > len(nodes) {
end = len(nodes)
}
udp2.send(n1.self, neighborsPacket, neighbors{
Nodes: nodes[off:end],
Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
})
}
// check that they are all returned. we cannot just check for
// equality because they might not be returned in the order they
// were sent.
var result []*Node
select {
case result = <-resultc:
case err := <-errc:
return err
}
if hasDuplicates(result) {
return fmt.Errorf("result slice contains duplicates")
}
if len(result) != len(nodes) {
return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
}
matched := make(map[NodeID]bool)
for _, n := range result {
for _, expn := range nodes {
if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
matched[n.ID] = true
}
}
}
if len(matched) != len(nodes) {
return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
}
return nil
})
if err != nil {
t.Error(err)
}
}
// runUDP runs a test n times and returns an error if the test failed
// in all n runs. This is necessary because UDP is unreliable even for
// connections on the local machine, causing test failures.
func runUDP(n int, test func() error) error {
errcount := 0
errors := ""
for i := 0; i < n; i++ {
if err := test(); err != nil {
errors += fmt.Sprintf("\n#%d: %v", i, err)
errcount++
}
}
if errcount == n {
return fmt.Errorf("failed on all %d iterations:%s", n, errors)
}
return nil
}
package p2p
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
......@@ -8,12 +9,37 @@ import (
"io"
"io/ioutil"
"math/big"
"net"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp"
)
// parameters for frameRW
const (
// maximum time allowed for reading a message header.
// this is effectively the amount of time a connection can be idle.
frameReadTimeout = 1 * time.Minute
// maximum time allowed for reading the payload data of a message.
// this is shorter than (and distinct from) frameReadTimeout because
// the connection is not considered idle while a message is transferred.
// this also limits the payload size of messages to how much the connection
// can transfer within the timeout.
payloadReadTimeout = 5 * time.Second
// maximum amount of time allowed for writing a complete message.
msgWriteTimeout = 5 * time.Second
// messages smaller than this many bytes will be read at
// once before passing them to a protocol. this increases
// concurrency in the processing.
wholePayloadSize = 64 * 1024
)
// Msg defines the structure of a p2p message.
//
// Note that a Msg can only be sent once since the Payload reader is
......@@ -74,11 +100,14 @@ type MsgWriter interface {
// WriteMsg sends a message. It will block until the message's
// Payload has been consumed by the other end.
//
// Note that messages can be sent only once.
// Note that messages can be sent only once because their
// payload reader is drained.
WriteMsg(Msg) error
}
// MsgReadWriter provides reading and writing of encoded messages.
// Implementations should ensure that ReadMsg and WriteMsg can be
// called simultaneously from multiple goroutines.
type MsgReadWriter interface {
MsgReader
MsgWriter
......@@ -90,8 +119,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...))
}
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
// As required by the interface, ReadMsg and WriteMsg can be called from
// multiple goroutines.
type frameRW struct {
net.Conn // make Conn methods available. be careful.
bufconn *bufio.ReadWriter
// this channel is used to 'lend' bufconn to a caller of ReadMsg
// until the message payload has been consumed. the channel
// receives a value when EOF is reached on the payload, unblocking
// a pending call to ReadMsg.
rsync chan struct{}
// this mutex guards writes to bufconn.
writeMu sync.Mutex
}
func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
rsync := make(chan struct{}, 1)
rsync <- struct{}{}
return &frameRW{
Conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
rsync: rsync,
}
}
var magicToken = []byte{34, 64, 8, 145}
func (rw *frameRW) WriteMsg(msg Msg) error {
rw.writeMu.Lock()
defer rw.writeMu.Unlock()
rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
if err := writeMsg(rw.bufconn, msg); err != nil {
return err
}
return rw.bufconn.Flush()
}
func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code))
......@@ -120,31 +186,51 @@ func makeListHeader(length uint32) []byte {
return append([]byte{lenb}, enc...)
}
// readMsg reads a message header from r.
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
func (rw *frameRW) ReadMsg() (msg Msg, err error) {
<-rw.rsync // wait until bufconn is ours
rw.SetReadDeadline(time.Now().Add(frameReadTimeout))
// read magic and payload size
start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil {
return msg, newPeerError(errRead, "%v", err)
if _, err = io.ReadFull(rw.bufconn, start); err != nil {
return msg, err
}
if !bytes.HasPrefix(start, magicToken) {
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
}
size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code
posr := &postrack{r, 0}
posr := &postrack{rw.bufconn, 0}
s := rlp.NewStream(posr)
if _, err := s.List(); err != nil {
return msg, err
}
code, err := s.Uint()
msg.Code, err = s.Uint()
if err != nil {
return msg, err
}
payloadsize := size - posr.p
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
msg.Size = size - posr.p
rw.SetReadDeadline(time.Now().Add(payloadReadTimeout))
if msg.Size <= wholePayloadSize {
// msg is small, read all of it and move on to the next message.
pbuf := make([]byte, msg.Size)
if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
return msg, err
}
rw.rsync <- struct{}{} // bufconn is available again
msg.Payload = bytes.NewReader(pbuf)
} else {
// lend bufconn to the caller until it has
// consumed the payload. eofSignal will send a value
// on rw.rsync when EOF is reached.
pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
msg.Payload = pr
}
return msg, nil
}
// postrack wraps an rlp.ByteReader with a position counter.
......@@ -167,6 +253,39 @@ func (r *postrack) ReadByte() (byte, error) {
return b, err
}
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
type eofSignal struct {
wrapped io.Reader
count uint32 // number of bytes left
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
if r.count == 0 {
if r.eof != nil {
r.eof <- struct{}{}
r.eof = nil
}
return 0, io.EOF
}
max := len(buf)
if int(r.count) < len(buf) {
max = int(r.count)
}
n, err := r.wrapped.Read(buf[:max])
r.count -= uint32(n)
if (err != nil || r.count == 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}
// MsgPipe creates a message pipe. Reads on one end are matched
// with writes on the other. The pipe is full-duplex, both ends
// implement MsgReadWriter.
......@@ -198,7 +317,7 @@ type MsgPipeRW struct {
func (p *MsgPipeRW) WriteMsg(msg Msg) error {
if atomic.LoadInt32(p.closed) == 0 {
consumed := make(chan struct{}, 1)
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed}
msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
select {
case p.w <- msg:
if msg.Size > 0 {
......
......@@ -3,12 +3,11 @@ package p2p
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"runtime"
"testing"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
func TestNewMsg(t *testing.T) {
......@@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
}
}
func TestEncodeDecodeMsg(t *testing.T) {
msg := NewMsg(3, 1, "000")
buf := new(bytes.Buffer)
if err := writeMsg(buf, msg); err != nil {
t.Fatalf("encodeMsg error: %v", err)
}
// t.Logf("encoded: %x", buf.Bytes())
// func TestEncodeDecodeMsg(t *testing.T) {
// msg := NewMsg(3, 1, "000")
// buf := new(bytes.Buffer)
// if err := writeMsg(buf, msg); err != nil {
// t.Fatalf("encodeMsg error: %v", err)
// }
// // t.Logf("encoded: %x", buf.Bytes())
decmsg, err := readMsg(buf)
if err != nil {
t.Fatalf("readMsg error: %v", err)
}
if decmsg.Code != 3 {
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
}
if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
}
// decmsg, err := readMsg(buf)
// if err != nil {
// t.Fatalf("readMsg error: %v", err)
// }
// if decmsg.Code != 3 {
// t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
// }
// if decmsg.Size != 5 {
// t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
// }
var data struct {
I uint
S string
}
if err := decmsg.Decode(&data); err != nil {
t.Fatalf("Decode error: %v", err)
}
if data.I != 1 {
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
}
if data.S != "000" {
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
}
}
// var data struct {
// I uint
// S string
// }
// if err := decmsg.Decode(&data); err != nil {
// t.Fatalf("Decode error: %v", err)
// }
// if data.I != 1 {
// t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
// }
// if data.S != "000" {
// t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
// }
// }
func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// func TestDecodeRealMsg(t *testing.T) {
// data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
// msg, err := readMsg(bytes.NewReader(data))
// if err != nil {
// t.Fatalf("unexpected error: %v", err)
// }
if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
}
}
// if msg.Code != 0 {
// t.Errorf("incorrect code %d, want %d", msg.Code, 0)
// }
// }
func ExampleMsgPipe() {
rw1, rw2 := MsgPipe()
......@@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
go rw1.Close()
}
}
func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)
// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
}
}
package p2p
import (
"fmt"
"net"
)
func ParseNAT(natType string, gateway string) (nat NAT, err error) {
switch natType {
case "UPNP":
nat = UPNP()
case "PMP":
ip := net.ParseIP(gateway)
if ip == nil {
return nil, fmt.Errorf("cannot resolve PMP gateway IP %s", gateway)
}
nat = PMP(ip)
case "":
default:
return nil, fmt.Errorf("unrecognised NAT type '%s'", natType)
}
return
}
// Package nat provides access to common port mapping protocols.
package nat
import (
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/ethereum/go-ethereum/logger"
"github.com/jackpal/go-nat-pmp"
)
var log = logger.NewLogger("P2P NAT")
// An implementation of nat.Interface can map local ports to ports
// accessible from the Internet.
type Interface interface {
// These methods manage a mapping between a port on the local
// machine to a port that can be connected to from the internet.
//
// protocol is "UDP" or "TCP". Some implementations allow setting
// a display name for the mapping. The mapping may be removed by
// the gateway when its lifetime ends.
AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
DeleteMapping(protocol string, extport, intport int) error
// This method should return the external (Internet-facing)
// address of the gateway device.
ExternalIP() (net.IP, error)
// Should return name of the method. This is used for logging.
String() string
}
// Parse parses a NAT interface description.
// The following formats are currently accepted.
// Note that mechanism names are not case-sensitive.
//
// "" or "none" return nil
// "extip:77.12.33.4" will assume the local machine is reachable on the given IP
// "any" uses the first auto-detected mechanism
// "upnp" uses the Universal Plug and Play protocol
// "pmp" uses NAT-PMP with an auto-detected gateway address
// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address
func Parse(spec string) (Interface, error) {
var (
parts = strings.SplitN(spec, ":", 2)
mech = strings.ToLower(parts[0])
ip net.IP
)
if len(parts) > 1 {
ip = net.ParseIP(parts[1])
if ip == nil {
return nil, errors.New("invalid IP address")
}
}
switch mech {
case "", "none", "off":
return nil, nil
case "any", "auto", "on":
return Any(), nil
case "extip", "ip":
if ip == nil {
return nil, errors.New("missing IP address")
}
return ExtIP(ip), nil
case "upnp":
return UPnP(), nil
case "pmp", "natpmp", "nat-pmp":
return PMP(ip), nil
default:
return nil, fmt.Errorf("unknown mechanism %q", parts[0])
}
}
const (
mapTimeout = 20 * time.Minute
mapUpdateInterval = 15 * time.Minute
)
// Map adds a port mapping on m and keeps it alive until c is closed.
// This function is typically invoked in its own goroutine.
func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) {
refresh := time.NewTimer(mapUpdateInterval)
defer func() {
refresh.Stop()
log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
m.DeleteMapping(protocol, extport, intport)
}()
log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
log.Errorf("mapping error: %v\n", err)
}
for {
select {
case _, ok := <-c:
if !ok {
return
}
case <-refresh.C:
log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
log.Errorf("mapping error: %v\n", err)
}
refresh.Reset(mapUpdateInterval)
}
}
}
// ExtIP assumes that the local machine is reachable on the given
// external IP address, and that any required ports were mapped manually.
// Mapping operations will not return an error but won't actually do anything.
func ExtIP(ip net.IP) Interface {
if ip == nil {
panic("IP must not be nil")
}
return extIP(ip)
}
type extIP net.IP
func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
// These do nothing.
func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
func (extIP) DeleteMapping(string, int, int) error { return nil }
// Any returns a port mapper that tries to discover any supported
// mechanism on the local network.
func Any() Interface {
// TODO: attempt to discover whether the local machine has an
// Internet-class address. Return ExtIP in this case.
return startautodisc("UPnP or NAT-PMP", func() Interface {
found := make(chan Interface, 2)
go func() { found <- discoverUPnP() }()
go func() { found <- discoverPMP() }()
for i := 0; i < cap(found); i++ {
if c := <-found; c != nil {
return c
}
}
return nil
})
}
// UPnP returns a port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPnP() Interface {
return startautodisc("UPnP", discoverUPnP)
}
// PMP returns a port mapper that uses NAT-PMP. The provided gateway
// address should be the IP of your router. If the given gateway
// address is nil, PMP will attempt to auto-discover the router.
func PMP(gateway net.IP) Interface {
if gateway != nil {
return &pmp{gw: gateway, c: natpmp.NewClient(gateway)}
}
return startautodisc("NAT-PMP", discoverPMP)
}
// autodisc represents a port mapping mechanism that is still being
// auto-discovered. Calls to the Interface methods on this type will
// wait until the discovery is done and then call the method on the
// discovered mechanism.
//
// This type is useful because discovery can take a while but we
// want return an Interface value from UPnP, PMP and Auto immediately.
type autodisc struct {
what string
done <-chan Interface
mu sync.Mutex
found Interface
}
func startautodisc(what string, doit func() Interface) Interface {
// TODO: monitor network configuration and rerun doit when it changes.
done := make(chan Interface)
ad := &autodisc{what: what, done: done}
go func() { done <- doit(); close(done) }()
return ad
}
func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
if err := n.wait(); err != nil {
return err
}
return n.found.AddMapping(protocol, extport, intport, name, lifetime)
}
func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error {
if err := n.wait(); err != nil {
return err
}
return n.found.DeleteMapping(protocol, extport, intport)
}
func (n *autodisc) ExternalIP() (net.IP, error) {
if err := n.wait(); err != nil {
return nil, err
}
return n.found.ExternalIP()
}
func (n *autodisc) String() string {
n.mu.Lock()
defer n.mu.Unlock()
if n.found == nil {
return n.what
} else {
return n.found.String()
}
}
func (n *autodisc) wait() error {
n.mu.Lock()
found := n.found
n.mu.Unlock()
if found != nil {
// already discovered
return nil
}
if found = <-n.done; found == nil {
return errors.New("no devices discovered")
}
n.mu.Lock()
n.found = found
n.mu.Unlock()
return nil
}
package nat
import (
"fmt"
"net"
"strings"
"time"
"github.com/jackpal/go-nat-pmp"
)
// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to
// the common interface.
type pmp struct {
gw net.IP
c *natpmp.Client
}
func (n *pmp) String() string {
return fmt.Sprintf("NAT-PMP(%v)", n.gw)
}
func (n *pmp) ExternalIP() (net.IP, error) {
response, err := n.c.GetExternalAddress()
if err != nil {
return nil, err
}
return response.ExternalIPAddress[:], nil
}
func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
if lifetime <= 0 {
return fmt.Errorf("lifetime must not be <= 0")
}
// Note order of port arguments is switched between our
// AddMapping and the client's AddPortMapping.
_, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second))
return err
}
func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) {
// To destroy a mapping, send an add-port with an internalPort of
// the internal port to destroy, an external port of zero and a
// time of zero.
_, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0)
return err
}
func discoverPMP() Interface {
// run external address lookups on all potential gateways
gws := potentialGateways()
found := make(chan *pmp, len(gws))
for i := range gws {
gw := gws[i]
go func() {
c := natpmp.NewClient(gw)
if _, err := c.GetExternalAddress(); err != nil {
found <- nil
} else {
found <- &pmp{gw, c}
}
}()
}
// return the one that responds first.
// discovery needs to be quick, so we stop caring about
// any responses after a very short timeout.
timeout := time.NewTimer(1 * time.Second)
defer timeout.Stop()
for _ = range gws {
select {
case c := <-found:
if c != nil {
return c
}
case <-timeout.C:
return nil
}
}
return nil
}
var (
// LAN IP ranges
_, lan10, _ = net.ParseCIDR("10.0.0.0/8")
_, lan176, _ = net.ParseCIDR("172.16.0.0/12")
_, lan192, _ = net.ParseCIDR("192.168.0.0/16")
)
// TODO: improve this. We currently assume that (on most networks)
// the router is X.X.X.1 in a local LAN range.
func potentialGateways() (gws []net.IP) {
ifaces, err := net.Interfaces()
if err != nil {
return nil
}
for _, iface := range ifaces {
ifaddrs, err := iface.Addrs()
if err != nil {
return gws
}
for _, addr := range ifaddrs {
switch x := addr.(type) {
case *net.IPNet:
if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
ip := x.IP.Mask(x.Mask).To4()
if ip != nil {
ip[3] = ip[3] | 0x01
gws = append(gws, ip)
}
}
}
}
}
return gws
}
package nat
import (
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/fjl/goupnp"
"github.com/fjl/goupnp/dcps/internetgateway1"
"github.com/fjl/goupnp/dcps/internetgateway2"
)
type upnp struct {
dev *goupnp.RootDevice
service string
client upnpClient
}
type upnpClient interface {
GetExternalIPAddress() (string, error)
AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error
DeletePortMapping(string, uint16, string) error
GetNATRSIPStatus() (sip bool, nat bool, err error)
}
func (n *upnp) ExternalIP() (addr net.IP, err error) {
ipString, err := n.client.GetExternalIPAddress()
if err != nil {
return nil, err
}
ip := net.ParseIP(ipString)
if ip == nil {
return nil, errors.New("bad IP in response")
}
return ip, nil
}
func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error {
ip, err := n.internalAddress()
if err != nil {
return nil
}
protocol = strings.ToUpper(protocol)
lifetimeS := uint32(lifetime / time.Second)
return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
}
func (n *upnp) internalAddress() (net.IP, error) {
devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host)
if err != nil {
return nil, err
}
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
for _, addr := range addrs {
switch x := addr.(type) {
case *net.IPNet:
if x.Contains(devaddr.IP) {
return x.IP, nil
}
}
}
}
return nil, fmt.Errorf("could not find local address in same net as %v", devaddr)
}
func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
}
func (n *upnp) String() string {
return "UPNP " + n.service
}
// discoverUPnP searches for Internet Gateway Devices
// and returns the first one it can find on the local network.
func discoverUPnP() Interface {
found := make(chan *upnp, 2)
// IGDv1
go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
switch sc.Service.ServiceType {
case internetgateway1.URN_WANIPConnection_1:
return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}}
case internetgateway1.URN_WANPPPConnection_1:
return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}}
}
return nil
})
// IGDv2
go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
switch sc.Service.ServiceType {
case internetgateway2.URN_WANIPConnection_1:
return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}}
case internetgateway2.URN_WANIPConnection_2:
return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}}
case internetgateway2.URN_WANPPPConnection_1:
return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}}
}
return nil
})
for i := 0; i < cap(found); i++ {
if c := <-found; c != nil {
return c
}
}
return nil
}
func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) {
devs, err := goupnp.DiscoverDevices(target)
if err != nil {
return
}
found := false
for i := 0; i < len(devs) && !found; i++ {
if devs[i].Root == nil {
continue
}
devs[i].Root.Device.VisitServices(func(service *goupnp.Service) {
if found {
return
}
// check for a matching IGD service
sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service}
upnp := matcher(devs[i].Root, sc)
if upnp == nil {
return
}
// check whether port mapping is enabled
if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat {
return
}
out <- upnp
found = true
})
}
if !found {
out <- nil
}
}
package p2p
import (
"fmt"
"net"
"time"
natpmp "github.com/jackpal/go-nat-pmp"
)
// Adapt the NAT-PMP protocol to the NAT interface
// TODO:
// + Register for changes to the external address.
// + Re-register port mapping when router reboots.
// + A mechanism for keeping a port mapping registered.
// + Discover gateway address automatically.
type natPMPClient struct {
client *natpmp.Client
}
// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
// address should be the IP of your router.
func PMP(gateway net.IP) (nat NAT) {
return &natPMPClient{natpmp.NewClient(gateway)}
}
func (*natPMPClient) String() string {
return "NAT-PMP"
}
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
response, err := n.client.GetExternalAddress()
if err != nil {
return nil, err
}
return response.ExternalIPAddress[:], nil
}
func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
if lifetime <= 0 {
return fmt.Errorf("lifetime must not be <= 0")
}
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
return err
}
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
// To destroy a mapping, send an add-port with
// an internalPort of the internal port to destroy, an external port of zero and a time of zero.
_, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
return
}
package p2p
// Just enough UPnP to be able to forward ports
//
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
)
const (
upnpDiscoverAttempts = 3
upnpDiscoverTimeout = 5 * time.Second
)
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
// discover the address of your router using UDP broadcasts.
func UPNP() NAT {
return &upnpNAT{}
}
type upnpNAT struct {
serviceURL string
ourIP string
}
func (n *upnpNAT) String() string {
return "UPNP"
}
func (n *upnpNAT) discover() error {
if n.serviceURL != "" {
// already discovered
return nil
}
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
if err != nil {
return err
}
// TODO: try on all network interfaces simultaneously.
// Broadcasting on 0.0.0.0 could select a random interface
// to send on (platform specific).
conn, err := net.ListenPacket("udp4", ":0")
if err != nil {
return err
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
buf := bytes.NewBufferString(
"M-SEARCH * HTTP/1.1\r\n" +
"HOST: 239.255.255.250:1900\r\n" +
st +
"MAN: \"ssdp:discover\"\r\n" +
"MX: 2\r\n\r\n")
message := buf.Bytes()
answerBytes := make([]byte, 1024)
for i := 0; i < upnpDiscoverAttempts; i++ {
_, err = conn.WriteTo(message, ssdp)
if err != nil {
return err
}
nn, _, err := conn.ReadFrom(answerBytes)
if err != nil {
continue
}
answer := string(answerBytes[0:nn])
if strings.Index(answer, "\r\n"+st) < 0 {
continue
}
// HTTP header field names are case-insensitive.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
locString := "\r\nlocation: "
answer = strings.ToLower(answer)
locIndex := strings.Index(answer, locString)
if locIndex < 0 {
continue
}
loc := answer[locIndex+len(locString):]
endIndex := strings.Index(loc, "\r\n")
if endIndex < 0 {
continue
}
locURL := loc[0:endIndex]
var serviceURL string
serviceURL, err = getServiceURL(locURL)
if err != nil {
return err
}
var ourIP string
ourIP, err = getOurIP()
if err != nil {
return err
}
n.serviceURL = serviceURL
n.ourIP = ourIP
return nil
}
return errors.New("UPnP port discovery failed.")
}
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
if err := n.discover(); err != nil {
return nil, err
}
info, err := n.getStatusInfo()
return net.ParseIP(info.externalIpAddress), err
}
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
if err := n.discover(); err != nil {
return err
}
// A single concatenation would break ARM compilation.
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
message += description +
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
"</NewLeaseDuration></u:AddPortMapping>"
// TODO: check response to see if the port was forwarded
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
return err
}
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
if err := n.discover(); err != nil {
return err
}
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
"</u:DeletePortMapping>"
// TODO: check response to see if the port was deleted
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
return err
}
type statusInfo struct {
externalIpAddress string
}
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
"</u:GetStatusInfo>"
var response *http.Response
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
if err != nil {
return
}
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
response.Body.Close()
return
}
// service represents the Service type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type service struct {
ServiceType string `xml:"serviceType"`
ControlURL string `xml:"controlURL"`
}
// deviceList represents the deviceList type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type deviceList struct {
XMLName xml.Name `xml:"deviceList"`
Device []device `xml:"device"`
}
// serviceList represents the serviceList type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type serviceList struct {
XMLName xml.Name `xml:"serviceList"`
Service []service `xml:"service"`
}
// device represents the device type in an UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type device struct {
XMLName xml.Name `xml:"device"`
DeviceType string `xml:"deviceType"`
DeviceList deviceList `xml:"deviceList"`
ServiceList serviceList `xml:"serviceList"`
}
// specVersion represents the specVersion in a UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type specVersion struct {
XMLName xml.Name `xml:"specVersion"`
Major int `xml:"major"`
Minor int `xml:"minor"`
}
// root represents the Root document for a UPnP xml description.
// Only the parts we care about are present and thus the xml may have more
// fields than present in the structure.
type root struct {
XMLName xml.Name `xml:"root"`
SpecVersion specVersion
Device device
}
func getChildDevice(d *device, deviceType string) *device {
dl := d.DeviceList.Device
for i := 0; i < len(dl); i++ {
if dl[i].DeviceType == deviceType {
return &dl[i]
}
}
return nil
}
func getChildService(d *device, serviceType string) *service {
sl := d.ServiceList.Service
for i := 0; i < len(sl); i++ {
if sl[i].ServiceType == serviceType {
return &sl[i]
}
}
return nil
}
func getOurIP() (ip string, err error) {
hostname, err := os.Hostname()
if err != nil {
return
}
p, err := net.LookupIP(hostname)
if err != nil && len(p) > 0 {
return
}
return p[0].String(), nil
}
func getServiceURL(rootURL string) (url string, err error) {
r, err := http.Get(rootURL)
if err != nil {
return
}
defer r.Body.Close()
if r.StatusCode >= 400 {
err = errors.New(string(r.StatusCode))
return
}
var root root
err = xml.NewDecoder(r.Body).Decode(&root)
if err != nil {
return
}
a := &root.Device
if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
err = errors.New("No InternetGatewayDevice")
return
}
b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
if b == nil {
err = errors.New("No WANDevice")
return
}
c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
if c == nil {
err = errors.New("No WANConnectionDevice")
return
}
d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
if d == nil {
err = errors.New("No WANIPConnection")
return
}
url = combineURL(rootURL, d.ControlURL)
return
}
func combineURL(rootURL, subURL string) string {
protocolEnd := "://"
protoEndIndex := strings.Index(rootURL, protocolEnd)
a := rootURL[protoEndIndex+len(protocolEnd):]
rootIndex := strings.Index(a, "/")
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
}
func soapRequest(url, function, message string) (r *http.Response, err error) {
fullMessage := "<?xml version=\"1.0\" ?>" +
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
"<s:Body>" + message + "</s:Body></s:Envelope>"
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
if err != nil {
return
}
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
//req.Header.Set("Transfer-Encoding", "chunked")
req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
req.Header.Set("Connection", "Close")
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache")
r, err = http.DefaultClient.Do(req)
if err != nil {
return
}
if r.Body != nil {
defer r.Body.Close()
}
if r.StatusCode >= 400 {
// log.Stderr(function, r.StatusCode)
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
r = nil
return
}
return
}
This diff is collapsed.
......@@ -12,7 +12,6 @@ const (
errInvalidMsgCode
errInvalidMsg
errP2PVersionMismatch
errPubkeyMissing
errPubkeyInvalid
errPubkeyForbidden
errProtocolBreach
......@@ -22,20 +21,19 @@ const (
)
var errorToString = map[int]string{
errMagicTokenMismatch: "Magic token mismatch",
errRead: "Read error",
errWrite: "Write error",
errMisc: "Misc error",
errInvalidMsgCode: "Invalid message code",
errInvalidMsg: "Invalid message",
errMagicTokenMismatch: "magic token mismatch",
errRead: "read error",
errWrite: "write error",
errMisc: "misc error",
errInvalidMsgCode: "invalid message code",
errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch",
errPubkeyMissing: "Public key missing",
errPubkeyInvalid: "Public key invalid",
errPubkeyForbidden: "Public key forbidden",
errProtocolBreach: "Protocol Breach",
errPingTimeout: "Ping timeout",
errInvalidNetworkId: "Invalid network id",
errInvalidProtocolVersion: "Invalid protocol version",
errPubkeyInvalid: "public key invalid",
errPubkeyForbidden: "public key forbidden",
errProtocolBreach: "protocol Breach",
errPingTimeout: "ping timeout",
errInvalidNetworkId: "invalid network id",
errInvalidProtocolVersion: "invalid protocol version",
}
type peerError struct {
......@@ -62,22 +60,22 @@ func (self *peerError) Error() string {
type DiscReason byte
const (
DiscRequested DiscReason = 0x00
DiscNetworkError = 0x01
DiscProtocolError = 0x02
DiscUselessPeer = 0x03
DiscTooManyPeers = 0x04
DiscAlreadyConnected = 0x05
DiscIncompatibleVersion = 0x06
DiscInvalidIdentity = 0x07
DiscQuitting = 0x08
DiscUnexpectedIdentity = 0x09
DiscSelf = 0x0a
DiscReadTimeout = 0x0b
DiscSubprotocolError = 0x10
DiscRequested DiscReason = iota
DiscNetworkError
DiscProtocolError
DiscUselessPeer
DiscTooManyPeers
DiscAlreadyConnected
DiscIncompatibleVersion
DiscInvalidIdentity
DiscQuitting
DiscUnexpectedIdentity
DiscSelf
DiscReadTimeout
DiscSubprotocolError
)
var discReasonToString = [DiscSubprotocolError + 1]string{
var discReasonToString = [...]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
......@@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
switch peerError.Code {
case errP2PVersionMismatch:
return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid:
case errPubkeyInvalid:
return DiscInvalidIdentity
case errPubkeyForbidden:
return DiscUselessPeer
......@@ -125,7 +123,7 @@ func discReasonForError(err error) DiscReason {
return DiscProtocolError
case errPingTimeout:
return DiscReadTimeout
case errRead, errWrite, errMisc:
case errRead, errWrite:
return DiscNetworkError
default:
return DiscSubprotocolError
......
This diff is collapsed.
package p2p
import (
"bytes"
"time"
)
// Protocol represents a P2P subprotocol implementation.
type Protocol struct {
// Name should contain the official protocol name,
......@@ -32,38 +27,6 @@ func (p Protocol) cap() Cap {
return Cap{p.Name, p.Version}
}
const (
baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the structure of a handshake list.
type handshake struct {
Version uint64
ID string
Caps []Cap
ListenPort uint64
NodeID []byte
}
func (h *handshake) String() string {
return h.ID
}
func (h *handshake) Pubkey() []byte {
return h.NodeID
}
// Cap is the structure of a peer capability.
type Cap struct {
Name string
......@@ -79,210 +42,3 @@ type capsByName []Cap
func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
}
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer}
errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err
}
// run main loop
go func() {
for {
if err := bp.handle(rw); err != nil {
errc <- err
break
}
}
}()
return bp.loop(errc)
}
var pingTimeout = 2 * time.Second
func (bp *baseProtocol) loop(quit <-chan error) error {
ping := time.NewTimer(pingTimeout)
activity := bp.peer.activity.Subscribe(time.Time{})
lastActive := time.Time{}
defer ping.Stop()
defer activity.Unsubscribe()
getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := EncodeMsg(bp.rw, getPeersMsg)
for err == nil {
select {
case err = <-quit:
return err
case <-getPeersTick.C:
err = EncodeMsg(bp.rw, getPeersMsg)
case event := <-activity.Chan():
ping.Reset(pingTimeout)
lastActive = event.(time.Time)
case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = EncodeMsg(bp.rw, pingMsg)
}
}
}
return err
}
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
// make sure that the payload has been fully consumed
defer msg.Discard()
switch msg.Code {
case handshakeMsg:
return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg:
var reason [1]DiscReason
if err := msg.Decode(&reason); err != nil {
return err
}
return discRequestedError(reason[0])
case pingMsg:
return EncodeMsg(bp.rw, pongMsg)
case pongMsg:
case getPeersMsg:
peers := bp.peerList()
// this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available.
// this means that would need to send a response later when
// new peers become available.
//
// TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 {
return EncodeMsg(bp.rw, peersMsg, peers...)
}
case peersMsg:
var peers []*peerAddr
if err := msg.Decode(&peers); err != nil {
return err
}
for _, addr := range peers {
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
}
default:
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
}
return nil
}
func (bp *baseProtocol) readHandshake() error {
// read and handle remote handshake
msg, err := bp.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if len(hs.NodeID) == 0 {
return newPeerError(errPubkeyMissing, "")
}
if len(hs.NodeID) != 64 {
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
}
if da := bp.peer.dialAddr; da != nil {
// verify that the peer we wanted to connect to
// actually holds the target public key.
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
}
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
}
// TODO: remove Caps with empty name
var addr *peerAddr
if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
addr.Port = hs.ListenPort
}
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.peer.startSubprotocols(hs.Caps)
return nil
}
func (bp *baseProtocol) handshakeMsg() Msg {
var (
port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
}
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
func (bp *baseProtocol) peerList() []interface{} {
peers := bp.peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
return l
}
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
......
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}
......@@ -350,8 +350,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
return writeUint, nil
case kind == reflect.String:
return writeString, nil
case kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 && !typ.Elem().Implements(encoderInterface):
case kind == reflect.Slice && isByte(typ.Elem()):
return writeBytes, nil
case kind == reflect.Array && isByte(typ.Elem()):
return writeByteArray, nil
case kind == reflect.Slice || kind == reflect.Array:
return makeSliceWriter(typ)
case kind == reflect.Struct:
......@@ -363,6 +365,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
}
}
func isByte(typ reflect.Type) bool {
return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
}
func writeUint(val reflect.Value, w *encbuf) error {
i := val.Uint()
if i == 0 {
......@@ -407,6 +413,20 @@ func writeBytes(val reflect.Value, w *encbuf) error {
return nil
}
func writeByteArray(val reflect.Value, w *encbuf) error {
if !val.CanAddr() {
// Slice requires the value to be addressable.
// Make it addressable by copying.
copy := reflect.New(val.Type()).Elem()
copy.Set(val)
val = copy
}
size := val.Len()
slice := val.Slice(0, size).Bytes()
w.encodeString(slice)
return nil
}
func writeString(val reflect.Value, w *encbuf) error {
s := val.String()
w.encodeStringHeader(len(s))
......
This diff is collapsed.
......@@ -215,7 +215,7 @@ func NewPeer(peer *p2p.Peer) *Peer {
return &Peer{
ref: peer,
Ip: fmt.Sprintf("%v", peer.RemoteAddr()),
Version: fmt.Sprintf("%v", peer.Identity()),
Version: fmt.Sprintf("%v", peer.ID()),
Caps: fmt.Sprintf("%v", caps),
}
}
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment