Unverified Commit 4eb92969 authored by Felföldi Zsolt's avatar Felföldi Zsolt Committed by GitHub

p2p/nodestate: ensure correct callback order (#21436)

This PR adds an extra guarantee to NodeStateMachine: it ensures that all
immediate effects of a given change are processed before any subsequent
effects that those immediate effects trigger on the same node. In the original
version, if a cascaded change caused a subscription callback to be called
multiple times for the same node, these calls could happen in the wrong
chronological order.

For example:

- a subscription to flag0 changes flag1 and flag2
- a subscription to flag1 changes flag3
- a subscription to flag1, flag2 and flag3 was called in the following order:

   [flag1] -> [flag1, flag3]
   [] -> [flag1]
   [flag1, flag3] -> [flag1, flag2, flag3]

This happened because the tree of changes was traversed in depth-first
order. Now it is traversed in breadth-first order: each node has a
FIFO queue for pending callbacks, and each triggered subscription callback
is appended to the end of that queue. The existing guarantees are
retained; no SetState or SetField call returns until the callback queue of the
node is empty again. As before, it is the responsibility of the
state machine design to ensure that infinite state loops are not possible.
Multiple changes affecting the same node can still happen simultaneously;
in this case the changes can be interleaved in the node's FIFO queue, but the
correct order is still guaranteed.

A new unit test is also added to verify callback order in the above scenario.
parent a99ac533
......@@ -166,7 +166,7 @@ func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *lpc.ValueTracker, d
if oldState.Equals(sfWaitDialTimeout) && newState.IsEmpty() {
// dial timeout, no connection
s.setRedialWait(n, dialCost, dialWaitStep)
s.ns.SetState(n, nodestate.Flags{}, sfDialing, 0)
s.ns.SetStateSub(n, nodestate.Flags{}, sfDialing, 0)
}
})
......@@ -193,10 +193,10 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod
if rand.Intn(maxQueryFails*2) < int(fails) {
// skip pre-negotiation with increasing chance, max 50%
// this ensures that the client can operate even if UDP is not working at all
s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10)
s.ns.SetStateSub(n, sfCanDial, nodestate.Flags{}, time.Second*10)
// set canDial before resetting queried so that FillSet will not read more
// candidates unnecessarily
s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0)
s.ns.SetStateSub(n, nodestate.Flags{}, sfQueried, 0)
return
}
go func() {
......@@ -206,12 +206,15 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod
} else {
atomic.StoreUint32(&s.queryFails, 0)
}
if q == 1 {
s.ns.SetState(n, sfCanDial, nodestate.Flags{}, time.Second*10)
} else {
s.setRedialWait(n, queryCost, queryWaitStep)
}
s.ns.SetState(n, nodestate.Flags{}, sfQueried, 0)
s.ns.Operation(func() {
// we are no longer running in the operation that the callback belongs to, start a new one because of setRedialWait
if q == 1 {
s.ns.SetStateSub(n, sfCanDial, nodestate.Flags{}, time.Second*10)
} else {
s.setRedialWait(n, queryCost, queryWaitStep)
}
s.ns.SetStateSub(n, nodestate.Flags{}, sfQueried, 0)
})
}()
}
})
......@@ -240,18 +243,20 @@ func (s *serverPool) start() {
}
}
unixTime := s.unixTime()
s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
s.calculateWeight(node)
if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime {
wait := n.redialWaitEnd - unixTime
lastWait := n.redialWaitEnd - n.redialWaitStart
if wait > lastWait {
// if the time until expiration is larger than the last suggested
// waiting time then the system clock was probably adjusted
wait = lastWait
s.ns.Operation(func() {
s.ns.ForEach(sfHasValue, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
s.calculateWeight(node)
if n, ok := s.ns.GetField(node, sfiNodeHistory).(nodeHistory); ok && n.redialWaitEnd > unixTime {
wait := n.redialWaitEnd - unixTime
lastWait := n.redialWaitEnd - n.redialWaitStart
if wait > lastWait {
// if the time until expiration is larger than the last suggested
// waiting time then the system clock was probably adjusted
wait = lastWait
}
s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second)
}
s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, time.Duration(wait)*time.Second)
}
})
})
}
......@@ -261,9 +266,11 @@ func (s *serverPool) stop() {
if s.fillSet != nil {
s.fillSet.Close()
}
s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) {
// recalculate weight of connected nodes in order to update hasValue flag if necessary
s.calculateWeight(n)
s.ns.Operation(func() {
s.ns.ForEach(sfConnected, nodestate.Flags{}, func(n *enode.Node, state nodestate.Flags) {
// recalculate weight of connected nodes in order to update hasValue flag if necessary
s.calculateWeight(n)
})
})
s.ns.Stop()
}
......@@ -279,9 +286,11 @@ func (s *serverPool) registerPeer(p *serverPeer) {
// unregisterPeer implements serverPeerSubscriber
func (s *serverPool) unregisterPeer(p *serverPeer) {
s.setRedialWait(p.Node(), dialCost, dialWaitStep)
s.ns.SetState(p.Node(), nodestate.Flags{}, sfConnected, 0)
s.ns.SetField(p.Node(), sfiConnectedStats, nil)
s.ns.Operation(func() {
s.setRedialWait(p.Node(), dialCost, dialWaitStep)
s.ns.SetStateSub(p.Node(), nodestate.Flags{}, sfConnected, 0)
s.ns.SetFieldSub(p.Node(), sfiConnectedStats, nil)
})
s.vt.Unregister(p.ID())
p.setValueTracker(nil, nil)
}
......@@ -380,14 +389,16 @@ func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue fl
// updateWeight calculates the node weight and updates the nodeWeight field and the
// hasValue flag. It also saves the node state if necessary.
// Note: this function should run inside a NodeStateMachine operation
func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) {
weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost))
if weight >= nodeWeightThreshold {
s.ns.SetState(node, sfHasValue, nodestate.Flags{}, 0)
s.ns.SetField(node, sfiNodeWeight, weight)
s.ns.SetStateSub(node, sfHasValue, nodestate.Flags{}, 0)
s.ns.SetFieldSub(node, sfiNodeWeight, weight)
} else {
s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0)
s.ns.SetField(node, sfiNodeWeight, nil)
s.ns.SetStateSub(node, nodestate.Flags{}, sfHasValue, 0)
s.ns.SetFieldSub(node, sfiNodeWeight, nil)
s.ns.SetFieldSub(node, sfiNodeHistory, nil)
}
s.ns.Persist(node) // saved if node history or hasValue changed
}
......@@ -400,6 +411,7 @@ func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDia
// a significant amount of service value again its waiting time is quickly reduced or reset
// to the minimum.
// Note: node weight is also recalculated and updated by this function.
// Note 2: this function should run inside a NodeStateMachine operation
func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) {
n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
sessionValue, totalValue := s.serviceValue(node)
......@@ -450,21 +462,22 @@ func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep
if wait < waitThreshold {
n.redialWaitStart = unixTime
n.redialWaitEnd = unixTime + int64(nextTimeout)
s.ns.SetField(node, sfiNodeHistory, n)
s.ns.SetState(node, sfRedialWait, nodestate.Flags{}, wait)
s.ns.SetFieldSub(node, sfiNodeHistory, n)
s.ns.SetStateSub(node, sfRedialWait, nodestate.Flags{}, wait)
s.updateWeight(node, totalValue, totalDialCost)
} else {
// discard known node statistics if waiting time is very long because the node
// hasn't been responsive for a very long time
s.ns.SetField(node, sfiNodeHistory, nil)
s.ns.SetField(node, sfiNodeWeight, nil)
s.ns.SetState(node, nodestate.Flags{}, sfHasValue, 0)
s.ns.SetFieldSub(node, sfiNodeHistory, nil)
s.ns.SetFieldSub(node, sfiNodeWeight, nil)
s.ns.SetStateSub(node, nodestate.Flags{}, sfHasValue, 0)
}
}
// calculateWeight calculates and sets the node weight without altering the node history.
// This function should be called during startup and shutdown only, otherwise setRedialWait
// will keep the weights updated as the underlying statistics are adjusted.
// Note: this function should run inside a NodeStateMachine operation
func (s *serverPool) calculateWeight(node *enode.Node) {
n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
_, totalValue := s.serviceValue(node)
......
This diff is collapsed.
......@@ -147,8 +147,13 @@ func TestSetField(t *testing.T) {
// Set field before setting state
ns.SetField(testNode(1), fields[0], "hello world")
field := ns.GetField(testNode(1), fields[0])
if field == nil {
t.Fatalf("Field should be set before setting states")
}
ns.SetField(testNode(1), fields[0], nil)
field = ns.GetField(testNode(1), fields[0])
if field != nil {
t.Fatalf("Field shouldn't be set before setting states")
t.Fatalf("Field should be unset")
}
// Set field after setting state
ns.SetState(testNode(1), flags[0], Flags{}, 0)
......@@ -169,23 +174,6 @@ func TestSetField(t *testing.T) {
}
}
// TestUnsetField sets a flag and a string field on a node, clears the flag
// again, and then requires GetField to return nil for that field.
// NOTE(review): this presumably relies on fields being dropped once the node
// has no flags left — confirm against the NodeStateMachine implementation.
func TestUnsetField(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
// One flag and one field of string type.
s, flags, fields := testSetup([]bool{false}, []reflect.Type{reflect.TypeOf("")})
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
ns.Start()
// Set flags[0] with a one-second timeout, then attach the field value.
ns.SetState(testNode(1), flags[0], Flags{}, time.Second)
ns.SetField(testNode(1), fields[0], "hello world")
// Reset flags[0]; the field is never cleared explicitly in this test.
ns.SetState(testNode(1), Flags{}, flags[0], 0)
if field := ns.GetField(testNode(1), fields[0]); field != nil {
t.Fatalf("Field should be unset")
}
}
func TestSetState(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
......@@ -339,6 +327,7 @@ func TestFieldSub(t *testing.T) {
ns2.Start()
check(s.OfflineFlag(), nil, uint64(100))
ns2.SetState(testNode(1), Flags{}, flags[0], 0)
ns2.SetField(testNode(1), fields[0], nil)
check(Flags{}, uint64(100), nil)
ns2.Stop()
}
......@@ -387,3 +376,34 @@ func TestDuplicatedFlags(t *testing.T) {
clock.Run(2 * time.Second)
check(flags[0], Flags{}, true)
}
// TestCallbackOrder checks that a subscriber watching flags[1], flags[2] and
// flags[3] observes an unbroken chain of transitions: the oldState passed to
// each callback must equal the newState passed to the previous one, starting
// from the empty flag set.
func TestCallbackOrder(t *testing.T) {
mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{}
// Four flags and no fields.
s, flags, _ := testSetup([]bool{false, false, false, false}, nil)
ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s)
// First cascade step: when newState equals flags[0], set flags[1] and then
// flags[2] on the same node from inside the callback.
ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) {
if newState.Equals(flags[0]) {
ns.SetStateSub(n, flags[1], Flags{}, 0)
ns.SetStateSub(n, flags[2], Flags{}, 0)
}
})
// Second cascade step: when newState equals flags[1], set flags[3].
ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) {
if newState.Equals(flags[1]) {
ns.SetStateSub(n, flags[3], Flags{}, 0)
}
})
// lastState holds the newState seen by the most recent observer callback;
// it starts as the empty flag set.
lastState := Flags{}
// Observer: fails the test as soon as a callback's oldState does not match
// the state recorded by the previous callback.
ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) {
if !oldState.Equals(lastState) {
t.Fatalf("Wrong callback order")
}
lastState = newState
})
ns.Start()
defer ns.Stop()
// Setting flags[0] is the single external change that starts the cascade.
ns.SetState(testNode(1), flags[0], Flags{}, 0)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment