diff options
author | algoidan <79864820+algoidan@users.noreply.github.com> | 2022-08-09 05:41:41 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-08 22:41:41 -0400 |
commit | 433b8815cfb8ac3b5df620dada2e1245b3761517 (patch) | |
tree | 4e03af5530a887a6d1dcaf8d23a295139a408992 | |
parent | 2fc3935e124b1e1047dd33629e91680ab230fe88 (diff) |
Stateproofs: fixes to lowestRound and CR comments (#4369)feature/stateproofs
24 files changed, 315 insertions, 196 deletions
diff --git a/catchup/catchpointService.go b/catchup/catchpointService.go index f4d753243..5d01fa960 100644 --- a/catchup/catchpointService.go +++ b/catchup/catchpointService.go @@ -19,6 +19,7 @@ package catchup import ( "context" "fmt" + "github.com/algorand/go-algorand/stateproof" "sync" "time" @@ -31,7 +32,6 @@ import ( "github.com/algorand/go-algorand/ledger/ledgercore" "github.com/algorand/go-algorand/logging" "github.com/algorand/go-algorand/network" - "github.com/algorand/go-algorand/protocol" ) const ( @@ -465,25 +465,17 @@ func (cs *CatchpointCatchupService) processStageLastestBlockDownload() (err erro return nil } -// lookbackForStateproofSupport calculates the lookback (from topblock round) needed to be downloaded +// lookbackForStateproofsSupport calculates the lookback (from topblock round) needed to be downloaded // in order to support state proofs verification. func lookbackForStateproofsSupport(topBlock *bookkeeping.Block) uint64 { proto := config.Consensus[topBlock.CurrentProtocol] if proto.StateProofInterval == 0 { return 0 } - //calculate the lowest round needed for verify the upcoming state proof txn. - expectedStateProofRound := topBlock.StateProofTracking[protocol.StateProofBasic].StateProofNextRound - stateproofVerificationRound := expectedStateProofRound.SubSaturate(basics.Round(proto.StateProofInterval)) - - lowestRecoveryRound := topBlock.Round().SubSaturate(topBlock.Round() % basics.Round(proto.StateProofInterval)) - lowestRecoveryRound = lowestRecoveryRound.SubSaturate(basics.Round(proto.StateProofInterval * (proto.StateProofMaxRecoveryIntervals + 1))) - - if lowestRecoveryRound > stateproofVerificationRound { - return uint64(topBlock.Round().SubSaturate(lowestRecoveryRound)) - } - - return uint64(topBlock.Round().SubSaturate(stateproofVerificationRound)) + lowestStateProofRound := stateproof.GetOldestExpectedStateProof(&topBlock.BlockHeader) + // in order to be able to confirm lowestStateProofRound we need to have round number: (lowestStateProofRound - stateproofInterval) + lowestStateProofRound = lowestStateProofRound.SubSaturate(basics.Round(proto.StateProofInterval)) + return uint64(topBlock.Round().SubSaturate(lowestStateProofRound)) } // processStageBlocksDownload is the fourth catchpoint catchup stage. It downloads all the reminder of the blocks, verifying each one of them against it's predecessor. diff --git a/components/mocks/mockParticipationRegistry.go b/components/mocks/mockParticipationRegistry.go index da9afad8c..d7a53c36f 100644 --- a/components/mocks/mockParticipationRegistry.go +++ b/components/mocks/mockParticipationRegistry.go @@ -69,9 +69,9 @@ func (m *MockParticipationRegistry) GetForRound(id account.ParticipationID, roun return account.ParticipationRecordForRound{}, nil } -// GetStateProofForRound fetches a record with stateproof secrets for a particular round. -func (m *MockParticipationRegistry) GetStateProofForRound(id account.ParticipationID, round basics.Round) (account.StateProofRecordForRound, error) { - return account.StateProofRecordForRound{}, nil +// GetStateProofSecretsForRound fetches a record with stateproof secrets for a particular round. +func (m *MockParticipationRegistry) GetStateProofSecretsForRound(id account.ParticipationID, round basics.Round) (account.StateProofSecretsForRound, error) { + return account.StateProofSecretsForRound{}, nil } // HasLiveKeys quickly tests to see if there is a valid participation key over some range of rounds diff --git a/crypto/stateproof/builder.go b/crypto/stateproof/builder.go index 42c13beb0..3f85656ab 100644 --- a/crypto/stateproof/builder.go +++ b/crypto/stateproof/builder.go @@ -46,6 +46,7 @@ type Builder struct { lnProvenWeight uint64 provenWeight uint64 strengthTarget uint64 + cachedProof *StateProof } // MakeBuilder constructs an empty builder. After adding enough signatures and signed weight, this builder is used to create a stateproof. @@ -66,6 +67,7 @@ func MakeBuilder(data MessageHash, round uint64, provenWeight uint64, part []bas lnProvenWeight: lnProvenWt, provenWeight: provenWeight, strengthTarget: strengthTarget, + cachedProof: nil, } return b, nil @@ -122,12 +124,13 @@ func (b *Builder) Add(pos uint64, sig merklesignature.Signature) error { b.sigs[pos].Weight = p.Weight b.sigs[pos].Sig = sig b.signedWeight += p.Weight + b.cachedProof = nil // can rebuild a more optimized state proof return nil } // Ready returns whether the state proof is ready to be built. func (b *Builder) Ready() bool { - return b.signedWeight > b.provenWeight + return b.cachedProof != nil || b.signedWeight > b.provenWeight } // SignedWeight returns the total weight of signatures added so far. @@ -167,6 +170,10 @@ again: // Build returns a state proof, if the builder has accumulated // enough signatures to construct it. func (b *Builder) Build() (*StateProof, error) { + if b.cachedProof != nil { + return b.cachedProof, nil + } + if !b.Ready() { return nil, fmt.Errorf("%w: %d <= %d", ErrSignedWeightLessThanProvenWeight, b.signedWeight, b.provenWeight) } @@ -248,6 +255,6 @@ func (b *Builder) Build() (*StateProof, error) { s.SigProofs = *sigProofs s.PartProofs = *partProofs s.PositionsToReveal = revealsSequence - + b.cachedProof = s return s, nil } diff --git a/crypto/stateproof/builder_test.go b/crypto/stateproof/builder_test.go index de8e5a8e7..780262e85 100644 --- a/crypto/stateproof/builder_test.go +++ b/crypto/stateproof/builder_test.go @@ -50,6 +50,8 @@ type paramsForTest struct { partCommitment crypto.GenericDigest numberOfParticipnets uint64 data MessageHash + builder *Builder + sig merklesignature.Signature } const stateProofIntervalForTests = 256 @@ -122,7 +124,7 @@ func generateProofForTesting(a *require.Assertions, doLargeTest bool) paramsForT b, err := MakeBuilder(data, stateProofIntervalForTests, uint64(totalWeight/2), parts, partcom, stateProofStrengthTargetForTests) a.NoError(err) - for i := uint64(0); i < uint64(npart); i++ { + for i := uint64(0); i < uint64(npart)/2+10; i++ { // leave some signature to be added later in the test (if needed) a.False(b.Present(i)) a.NoError(b.IsValid(i, &sigs[i], !doLargeTest)) b.Add(i, sigs[i]) @@ -142,6 +144,8 @@ func generateProofForTesting(a *require.Assertions, doLargeTest bool) paramsForT partCommitment: partcom.Root(), numberOfParticipnets: uint64(npart), data: data, + builder: b, + sig: sig, } return p } @@ -605,6 +609,28 @@ func TestBuilderWithZeroProvenWeight(t *testing.T) { } +func TestBuilder_BuildStateProofCache(t *testing.T) { + partitiontest.PartitionTest(t) + a := require.New(t) + p := generateProofForTesting(a, true) + sp1 := &p.sp + sp2, err := p.builder.Build() + a.NoError(err) + a.Equal(sp1, sp2) // already built, no signatures added + + err = p.builder.Add(p.numberOfParticipnets-1, p.sig) + a.NoError(err) + sp3, err := p.builder.Build() + a.NoError(err) + a.NotEqual(sp1, sp3) // better StateProof with added signature should have been built + + sp4, err := p.builder.Build() + a.NoError(err) + a.Equal(sp3, sp4) + + return +} + func BenchmarkBuildVerify(b *testing.B) { totalWeight := 1000000 npart := 1000 diff --git a/crypto/stateproof/weights.go b/crypto/stateproof/weights.go index 6d32eda48..8d0bdd13d 100644 --- a/crypto/stateproof/weights.go +++ b/crypto/stateproof/weights.go @@ -79,7 +79,7 @@ func verifyWeights(signedWeight uint64, lnProvenWeight uint64, numOfReveals uint // /\ // || // \/ - // numReveals * (x + w * y >= ((strengthTarget) * T + numReveals * P) * y + // numReveals * (x + w * y) >= ((strengthTarget) * T + numReveals * P) * y y, x, w := getSubExpressions(signedWeight) lhs := &big.Int{} lhs.Set(w). diff --git a/daemon/algod/api/server/v2/test/helpers.go b/daemon/algod/api/server/v2/test/helpers.go index 22bc93511..ba2a29549 100644 --- a/daemon/algod/api/server/v2/test/helpers.go +++ b/daemon/algod/api/server/v2/test/helpers.go @@ -80,7 +80,6 @@ var txnPoolGolden = make([]transactions.SignedTxn, 2) // ordinarily mockNode would live in `components/mocks` // but doing this would create an import cycle, as mockNode needs -// but doing this would create an import cycle, as mockNode needs // package `data` and package `node`, which themselves import `mocks` type mockNode struct { ledger v2.LedgerForAPI diff --git a/data/account/participationRegistry.go b/data/account/participationRegistry.go index 72adcf8b3..e36f22451 100644 --- a/data/account/participationRegistry.go +++ b/data/account/participationRegistry.go @@ -94,10 +94,10 @@ type ( ParticipationRecord } - // StateProofRecordForRound contains participant's state proof secrets that corresponds to + // StateProofSecretsForRound contains participant's state proof secrets that corresponds to // one specific round. In Addition, it also returns the participation metadata. // If there are no secrets for the round a nil is returned in Stateproof field. - StateProofRecordForRound struct { + StateProofSecretsForRound struct { ParticipationRecord StateProofSecrets *merklesignature.Signer @@ -247,8 +247,8 @@ type ParticipationRegistry interface { // GetForRound fetches a record with voting secrets for a particular round. GetForRound(id ParticipationID, round basics.Round) (ParticipationRecordForRound, error) - // GetStateProofForRound fetches a record with stateproof secrets for a particular round. - GetStateProofForRound(id ParticipationID, round basics.Round) (StateProofRecordForRound, error) + // GetStateProofSecretsForRound fetches a record with stateproof secrets for a particular round. + GetStateProofSecretsForRound(id ParticipationID, round basics.Round) (StateProofSecretsForRound, error) // HasLiveKeys quickly tests to see if there is a valid participation key over some range of rounds HasLiveKeys(from, to basics.Round) bool @@ -761,14 +761,14 @@ func (db *participationDB) HasLiveKeys(from, to basics.Round) bool { return false } -// GetStateProofForRound returns the state proof data required to sign the compact certificate for this round -func (db *participationDB) GetStateProofForRound(id ParticipationID, round basics.Round) (StateProofRecordForRound, error) { +// GetStateProofSecretsForRound returns the state proof data required to sign the compact certificate for this round +func (db *participationDB) GetStateProofSecretsForRound(id ParticipationID, round basics.Round) (StateProofSecretsForRound, error) { partRecord, err := db.GetForRound(id, round) if err != nil { - return StateProofRecordForRound{}, err + return StateProofSecretsForRound{}, err } - var result StateProofRecordForRound + var result StateProofSecretsForRound result.ParticipationRecord = partRecord.ParticipationRecord var rawStateProofKey []byte err = db.store.Rdb.Atomic(func(ctx context.Context, tx *sql.Tx) error { @@ -790,7 +790,7 @@ func (db *participationDB) GetStateProofForRound(id ParticipationID, round basic return nil }) if err != nil { - return StateProofRecordForRound{}, fmt.Errorf("failed to fetch state proof for round %d: %w", round, err) + return StateProofSecretsForRound{}, fmt.Errorf("failed to fetch state proof for round %d: %w", round, err) } // Init stateproof fields after being able to retrieve key from database @@ -800,7 +800,7 @@ func (db *participationDB) GetStateProofForRound(id ParticipationID, round basic err = protocol.Decode(rawStateProofKey, result.StateProofSecrets.SigningKey) if err != nil { - return StateProofRecordForRound{}, err + return StateProofSecretsForRound{}, err } var rawSignerContext []byte @@ -814,11 +814,11 @@ func (db *participationDB) GetStateProofForRound(id ParticipationID, round basic return nil }) if err != nil { - return StateProofRecordForRound{}, err + return StateProofSecretsForRound{}, err } err = protocol.Decode(rawSignerContext, &result.StateProofSecrets.SignerContext) if err != nil { - return StateProofRecordForRound{}, err + return StateProofSecretsForRound{}, err } return result, nil } diff --git a/data/account/participationRegistry_test.go b/data/account/participationRegistry_test.go index 38a3d89bb..61b78adac 100644 --- a/data/account/participationRegistry_test.go +++ b/data/account/participationRegistry_test.go @@ -936,14 +936,14 @@ func TestAddStateProofKeys(t *testing.T) { err = registry.Flush(10 * time.Second) a.NoError(err) - _, err = registry.GetStateProofForRound(id, basics.Round(1)) + _, err = registry.GetStateProofSecretsForRound(id, basics.Round(1)) a.Error(err) - _, err = registry.GetStateProofForRound(id, basics.Round(2)) + _, err = registry.GetStateProofSecretsForRound(id, basics.Round(2)) a.Error(err) // Make sure we're able to fetch the same data that was put in. for i := uint64(3); i < max; i++ { - r, err := registry.GetStateProofForRound(id, basics.Round(i)) + r, err := registry.GetStateProofSecretsForRound(id, basics.Round(i)) a.NoError(err) if r.StateProofSecrets != nil { @@ -1037,10 +1037,10 @@ func TestGetRoundSecretsWithoutStateProof(t *testing.T) { a.NoError(registry.Flush(defaultTimeout)) - partPerRound, err := registry.GetStateProofForRound(id, 1) + partPerRound, err := registry.GetStateProofSecretsForRound(id, 1) a.Error(err) - partPerRound, err = registry.GetStateProofForRound(id, basics.Round(stateProofIntervalForTests)) + partPerRound, err = registry.GetStateProofSecretsForRound(id, basics.Round(stateProofIntervalForTests)) a.Error(err) // Append key @@ -1052,10 +1052,10 @@ func TestGetRoundSecretsWithoutStateProof(t *testing.T) { a.NoError(registry.Flush(defaultTimeout)) - partPerRound, err = registry.GetStateProofForRound(id, basics.Round(stateProofIntervalForTests)-1) + partPerRound, err = registry.GetStateProofSecretsForRound(id, basics.Round(stateProofIntervalForTests)-1) a.Error(err) - partPerRound, err = registry.GetStateProofForRound(id, basics.Round(stateProofIntervalForTests)) + partPerRound, err = registry.GetStateProofSecretsForRound(id, basics.Round(stateProofIntervalForTests)) a.NoError(err) a.NotNil(partPerRound.StateProofSecrets) @@ -1099,7 +1099,7 @@ func TestDeleteStateProofKeys(t *testing.T) { // Make sure we're able to fetch the same data that was put in. for i := uint64(4); i < maxRound; i += 4 { - r, err := registry.GetStateProofForRound(id, basics.Round(i)) + r, err := registry.GetStateProofSecretsForRound(id, basics.Round(i)) a.NoError(err) a.Equal(keys.findPairForSpecificRound(i).Key, r.StateProofSecrets.SigningKey) @@ -1238,7 +1238,7 @@ func TestParticipationDB_Locking(t *testing.T) { time.Sleep(time.Second) goto repeat } - _, err = registry.GetStateProofForRound(id2, basics.Round(256)) + _, err = registry.GetStateProofSecretsForRound(id2, basics.Round(256)) // The error we're trying to avoid is "database is locked", since we're reading from StateProofKeys table, // while the main thread is updating the Rolling table. a.NoError(err) @@ -1295,7 +1295,7 @@ func TestParticipationDBInstallWhileReading(t *testing.T) { <-appendedKeys // Makes sure we start fetching keys after the append keys operation has already started for i := 0; i < 50; i++ { - _, err = registry.GetStateProofForRound(sampledPartID, basics.Round(256)) + _, err = registry.GetStateProofSecretsForRound(sampledPartID, basics.Round(256)) // The error we're trying to avoid is "database is locked", since we're reading from StateProofKeys table, // while a different go routine is installing new keys. a.NoError(err) diff --git a/data/accountManager.go b/data/accountManager.go index cd79c0633..d44091f80 100644 --- a/data/accountManager.go +++ b/data/accountManager.go @@ -77,10 +77,10 @@ func (manager *AccountManager) Keys(rnd basics.Round) (out []account.Participati } // StateProofKeys returns a list of Participation accounts, and their stateproof secrets -func (manager *AccountManager) StateProofKeys(rnd basics.Round) (out []account.StateProofRecordForRound) { +func (manager *AccountManager) StateProofKeys(rnd basics.Round) (out []account.StateProofSecretsForRound) { for _, part := range manager.registry.GetAll() { if part.OverlapsInterval(rnd, rnd) { - partRndSecrets, err := manager.registry.GetStateProofForRound(part.ParticipationID, rnd) + partRndSecrets, err := manager.registry.GetStateProofSecretsForRound(part.ParticipationID, rnd) if err != nil { manager.log.Errorf("error while loading round secrets from participation registry: %w", err) continue diff --git a/data/pools/transactionPool_test.go b/data/pools/transactionPool_test.go index b559b71e5..d4c863d44 100644 --- a/data/pools/transactionPool_test.go +++ b/data/pools/transactionPool_test.go @@ -1383,7 +1383,7 @@ func TestTStateProofLogging(t *testing.T) { votersRoundHdr, err := mockLedger.BlockHdr(votersRound) require.NoError(t, err) - provenWeight, err := verify.GetProvenWeight(votersRoundHdr, spRoundHdr) + provenWeight, err := verify.GetProvenWeight(&votersRoundHdr, &spRoundHdr) require.NoError(t, err) lookback := votersRound.SubSaturate(basics.Round(proto.StateProofVotersLookback)) diff --git a/ledger/accountdb.go b/ledger/accountdb.go index 99634162e..be242e13e 100644 --- a/ledger/accountdb.go +++ b/ledger/accountdb.go @@ -54,9 +54,9 @@ type accountsDbQueries struct { } type onlineAccountsDbQueries struct { - lookupOnlineStmt *sql.Stmt - lookupOnlineHistoryStmt *sql.Stmt - lookupOnlineTotalsHistoryStmt *sql.Stmt + lookupOnlineStmt *sql.Stmt + lookupOnlineHistoryStmt *sql.Stmt + lookupOnlineTotalsStmt *sql.Stmt } var accountsSchema = []string{ @@ -2517,7 +2517,7 @@ func onlineAccountsInitDbQueries(r db.Queryable) (*onlineAccountsDbQueries, erro return nil, err } - qs.lookupOnlineTotalsHistoryStmt, err = r.Prepare("SELECT data FROM onlineroundparamstail WHERE rnd=?") + qs.lookupOnlineTotalsStmt, err = r.Prepare("SELECT data FROM onlineroundparamstail WHERE rnd=?") if err != nil { return nil, err } @@ -2724,7 +2724,7 @@ func (qs *onlineAccountsDbQueries) lookupOnline(addr basics.Address, rnd basics. func (qs *onlineAccountsDbQueries) lookupOnlineTotalsHistory(round basics.Round) (basics.MicroAlgos, error) { data := ledgercore.OnlineRoundParamsData{} err := db.Retry(func() error { - row := qs.lookupOnlineTotalsHistoryStmt.QueryRow(round) + row := qs.lookupOnlineTotalsStmt.QueryRow(round) var buf []byte err := row.Scan(&buf) if err != nil { diff --git a/ledger/apply/apply.go b/ledger/apply/apply.go index a37e1a86e..34017e383 100644 --- a/ledger/apply/apply.go +++ b/ledger/apply/apply.go @@ -25,8 +25,8 @@ import ( "github.com/algorand/go-algorand/ledger/ledgercore" ) -// StateProofs allows fetching and updating state-proofs state on the ledger -type StateProofs interface { +// StateProofsApplier allows fetching and updating state-proofs state on the ledger +type StateProofsApplier interface { BlockHdr(r basics.Round) (bookkeeping.BlockHeader, error) GetStateProofNextRound() basics.Round SetStateProofNextRound(rnd basics.Round) diff --git a/ledger/apply/stateproof.go b/ledger/apply/stateproof.go index 866c6ef67..fca56c0f3 100644 --- a/ledger/apply/stateproof.go +++ b/ledger/apply/stateproof.go @@ -17,7 +17,9 @@ package apply import ( + "errors" "fmt" + "github.com/algorand/go-algorand/config" "github.com/algorand/go-algorand/data/basics" "github.com/algorand/go-algorand/data/transactions" @@ -25,36 +27,45 @@ import ( "github.com/algorand/go-algorand/stateproof/verify" ) +// Errors for apply stateproof +var ( + ErrStateProofTypeNotSupported = errors.New("state proof type not supported") + ErrExpectedDifferentStateProofRound = errors.New("expected different state proof round") +) + // StateProof applies the StateProof transaction and setting the next StateProof round -func StateProof(tx transactions.StateProofTxnFields, atRound basics.Round, sp StateProofs, validate bool) error { +func StateProof(tx transactions.StateProofTxnFields, atRound basics.Round, sp StateProofsApplier, validate bool) error { spType := tx.StateProofType if spType != protocol.StateProofBasic { - return fmt.Errorf("applyStateProof type %d not supported", spType) + return fmt.Errorf("applyStateProof: %w - type %d ", ErrStateProofTypeNotSupported, spType) } - nextStateProofRnd := sp.GetStateProofNextRound() - - latestRoundInInterval := basics.Round(tx.Message.LastAttestedRound) - latestRoundHdr, err := sp.BlockHdr(latestRoundInInterval) + lastRoundInInterval := basics.Round(tx.Message.LastAttestedRound) + lastRoundHdr, err := sp.BlockHdr(lastRoundInInterval) if err != nil { return err } - proto := config.Consensus[latestRoundHdr.CurrentProtocol] + nextStateProofRnd := sp.GetStateProofNextRound() + if nextStateProofRnd == 0 || nextStateProofRnd != lastRoundInInterval { + return fmt.Errorf("applyStateProof: %w - expecting state proof for %d, but new state proof is for %d", + ErrExpectedDifferentStateProofRound, nextStateProofRnd, lastRoundInInterval) + } + proto := config.Consensus[lastRoundHdr.CurrentProtocol] if validate { - votersRnd := latestRoundInInterval.SubSaturate(basics.Round(proto.StateProofInterval)) + votersRnd := lastRoundInInterval.SubSaturate(basics.Round(proto.StateProofInterval)) votersHdr, err := sp.BlockHdr(votersRnd) if err != nil { return err } - err = verify.ValidateStateProof(&latestRoundHdr, &tx.StateProof, &votersHdr, nextStateProofRnd, atRound, &tx.Message) + err = verify.ValidateStateProof(&lastRoundHdr, &tx.StateProof, &votersHdr, atRound, &tx.Message) if err != nil { return err } } - sp.SetStateProofNextRound(latestRoundInInterval + basics.Round(proto.StateProofInterval)) + sp.SetStateProofNextRound(lastRoundInInterval + basics.Round(proto.StateProofInterval)) return nil } diff --git a/ledger/internal/eval_test.go b/ledger/internal/eval_test.go index 4928a4a0b..8f07612be 100644 --- a/ledger/internal/eval_test.go +++ b/ledger/internal/eval_test.go @@ -224,7 +224,7 @@ func TestCowStateProof(t *testing.T) { Message: msg, } err := apply.StateProof(stateProofTx, atRound, c0, validate) - require.Contains(t, err.Error(), "not supported") + require.ErrorIs(t, err, apply.ErrStateProofTypeNotSupported) // no spRnd block stateProofTx.StateProofType = protocol.StateProofBasic @@ -234,6 +234,20 @@ func TestCowStateProof(t *testing.T) { err = apply.StateProof(stateProofTx, atRound, c0, validate) require.Contains(t, err.Error(), "no block") + // stateproof txn doesn't confirm the next state proof round. expected is in the past + validate = true + stateProofTx.Message.LastAttestedRound = uint64(16) + c0.SetStateProofNextRound(8) + err = apply.StateProof(stateProofTx, atRound, c0, validate) + require.ErrorIs(t, err, apply.ErrExpectedDifferentStateProofRound) + + // stateproof txn doesn't confirm the next state proof round. expected is in the future + validate = true + stateProofTx.Message.LastAttestedRound = uint64(16) + c0.SetStateProofNextRound(32) + err = apply.StateProof(stateProofTx, atRound, c0, validate) + require.ErrorIs(t, err, apply.ErrExpectedDifferentStateProofRound) + // no votersRnd block // this is slightly a mess of things that don't quite line up with likely usage validate = true @@ -248,16 +262,11 @@ func TestCowStateProof(t *testing.T) { spHdr.Round = 15 blocks[spHdr.Round] = spHdr stateProofTx.Message.LastAttestedRound = uint64(spHdr.Round) + c0.SetStateProofNextRound(15) blockErr[13] = noBlockErr err = apply.StateProof(stateProofTx, atRound, c0, validate) require.Contains(t, err.Error(), "no block") - // validate fail - spHdr.Round = 1 - stateProofTx.Message.LastAttestedRound = uint64(spHdr.Round) - err = apply.StateProof(stateProofTx, atRound, c0, validate) - require.Contains(t, err.Error(), "state proof is not in a valid round multiple") - // fall through to no err validate = false err = apply.StateProof(stateProofTx, atRound, c0, validate) diff --git a/ledger/ledger_test.go b/ledger/ledger_test.go index 0ecb32f2c..d925cec61 100644 --- a/ledger/ledger_test.go +++ b/ledger/ledger_test.go @@ -1657,7 +1657,9 @@ func TestListAssetsAndApplications(t *testing.T) { func TestLedgerKeepsOldBlocksForStateProof(t *testing.T) { partitiontest.PartitionTest(t) - maxBlocks := int((config.Consensus[protocol.ConsensusFuture].StateProofMaxRecoveryIntervals + 1) * config.Consensus[protocol.ConsensusFuture].StateProofInterval) + // since the first state proof is expected to happen on stateproofInterval*2 we would start give-up on state proofs we would + // give up on old state proofs only after stateproofInterval*3 + maxBlocks := int((config.Consensus[protocol.ConsensusFuture].StateProofMaxRecoveryIntervals + 2) * config.Consensus[protocol.ConsensusFuture].StateProofInterval) dbName := fmt.Sprintf("%s.%d", t.Name(), crypto.RandUint64()) genesisInitState, initKeys := ledgertesting.GenerateInitState(t, protocol.ConsensusFuture, 10000000000) @@ -1707,6 +1709,8 @@ func TestLedgerKeepsOldBlocksForStateProof(t *testing.T) { backlogPool := execpool.MakeBacklog(nil, 0, execpool.LowPriority, nil) defer backlogPool.Shutdown() + // On this round there is no give up on any state proof - so we would be able to verify an old state proof txn. + // We now create block with stateproof transaction. since we don't want to complicate the test and create // a cryptographically correct stateproof we would make sure that only the crypto part of the verification fails. blk := createBlkWithStateproof(t, maxBlocks, proto, genesisInitState, l, accounts) @@ -1717,6 +1721,7 @@ func TestLedgerKeepsOldBlocksForStateProof(t *testing.T) { addDummyBlock(t, addresses, proto, l, initKeys, genesisInitState) } + l.WaitForCommit(l.Latest()) // at the point the ledger would remove the voters round for the database. // that will cause the stateproof transaction verification to fail because there are // missing blocks @@ -2777,14 +2782,16 @@ func TestVotersReloadFromDiskPassRecoveryPeriod(t *testing.T) { protocol.StateProofBasic: sp, } - for i := uint64(0); i < (proto.StateProofInterval * (proto.StateProofMaxRecoveryIntervals + 1)); i++ { + // we push proto.StateProofInterval * (proto.StateProofMaxRecoveryIntervals + 2) block into the ledger + // the reason for + 2 is the first state proof is on 2*stateproofinterval. + for i := uint64(0); i < (proto.StateProofInterval * (proto.StateProofMaxRecoveryIntervals + 2)); i++ { blk.BlockHeader.Round++ blk.BlockHeader.TimeStamp += 10 err = l.AddBlock(blk, agreement.Certificate{}) require.NoError(t, err) } - // the voters tracker should contains all the voters for each stateproof round. nothing should be removed + // the voters tracker should contain all the voters for each stateproof round. nothing should be removed l.WaitForCommit(blk.BlockHeader.Round) vtSnapshot := l.acctsOnline.voters.votersForRoundCache beforeRemoveVotersLen := len(vtSnapshot) diff --git a/ledger/voters.go b/ledger/voters.go index 73c772658..d0a76a6cd 100644 --- a/ledger/voters.go +++ b/ledger/voters.go @@ -18,6 +18,7 @@ package ledger import ( "fmt" + "github.com/algorand/go-algorand/stateproof" "sync" "github.com/algorand/go-algorand/config" @@ -59,6 +60,14 @@ type votersTracker struct { // the vector commitment to online accounts from the previous such block. // Thus, we maintain X in the votersForRoundCache map until we form a stateproof // for round X+StateProofVotersLookback+StateProofInterval. + // + // In case state proof chain stalls this map would be bounded to StateProofMaxRecoveryIntervals + 3 + // + 1 - since votersForRoundCache needs to contain an entry for a future state proof + // + 1 - since votersForRoundCache needs to contain an entry to verify the earliest state proof + // in the recovery interval. i.e. it needs to have an entry for R-StateProofMaxRecoveryIntervals-StateProofInterval + // to verify R-StateProofMaxRecoveryIntervals + // + 1 would only appear if the sampled round R is: interval - lookback < R < interval. + // in this case, the tracker would not yet remove the old one but will create a new one for future state proof. votersForRoundCache map[basics.Round]*ledgercore.VotersForRound l ledgerForTracker @@ -71,7 +80,7 @@ type votersTracker struct { // votersRoundForStateProofRound computes the round number whose voting participants // will be used to sign the state proof for stateProofRnd. -func votersRoundForStateProofRound(stateProofRnd basics.Round, proto config.ConsensusParams) basics.Round { +func votersRoundForStateProofRound(stateProofRnd basics.Round, proto *config.ConsensusParams) basics.Round { // To form a state proof on period that ends on stateProofRnd, // we need a commitment to the voters StateProofInterval rounds // before that, and the voters information from @@ -84,7 +93,8 @@ func (vt *votersTracker) loadFromDisk(l ledgerForTracker, fetcher ledgercore.Onl vt.votersForRoundCache = make(map[basics.Round]*ledgercore.VotersForRound) vt.onlineAccountsFetcher = fetcher - hdr, err := l.BlockHdr(latestDbRound) + latestRoundInLedger := l.Latest() + hdr, err := l.BlockHdr(latestRoundInLedger) if err != nil { return err } @@ -95,7 +105,8 @@ func (vt *votersTracker) loadFromDisk(l ledgerForTracker, fetcher ledgercore.Onl return nil } - startR := votersRoundForStateProofRound(hdr.StateProofTracking[protocol.StateProofBasic].StateProofNextRound, proto) + startR := stateproof.GetOldestExpectedStateProof(&hdr) + startR = votersRoundForStateProofRound(startR, &proto) // Sanity check: we should never underflow or even reach 0. if startR == 0 { @@ -103,11 +114,12 @@ func (vt *votersTracker) loadFromDisk(l ledgerForTracker, fetcher ledgercore.Onl hdr.StateProofTracking[protocol.StateProofBasic].StateProofNextRound, proto.StateProofInterval, proto.StateProofVotersLookback, startR) } + // we recreate the trees for old rounds. we stop at latestDbRound (where latestDbRound <= latestRoundInLedger) since + // future blocks would be given as part of the replay for r := startR; r <= latestDbRound; r += basics.Round(proto.StateProofInterval) { hdr, err = l.BlockHdr(r) if err != nil { - vt.l.trackerLog().Errorf("votersTracker: loadFromDisk: cannot load tree for round %v, err : %v", r, err) - continue + return err } vt.loadTree(hdr) @@ -192,18 +204,14 @@ func (vt *votersTracker) newBlock(hdr bookkeeping.BlockHeader) { // Since the map is small (Usually 0 - 2 elements and up to StateProofMaxRecoveryIntervals) we decided to keep the code simple // and check for deletion in every round. func (vt *votersTracker) removeOldVoters(hdr bookkeeping.BlockHeader) { - // we calculate the lowest round for recovery according to the newest round (might be different from the rounds on cache) - proto := config.Consensus[hdr.CurrentProtocol] - recentRoundOnRecoveryPeriod := basics.Round(uint64(hdr.Round) - uint64(hdr.Round)%proto.StateProofInterval) - oldestRoundOnRecoveryPeriod := recentRoundOnRecoveryPeriod.SubSaturate(basics.Round(proto.StateProofInterval * proto.StateProofMaxRecoveryIntervals)) + lowestStateProofRound := stateproof.GetOldestExpectedStateProof(&hdr) for r, tr := range vt.votersForRoundCache { commitRound := r + basics.Round(tr.Proto.StateProofVotersLookback) stateProofRound := commitRound + basics.Round(tr.Proto.StateProofInterval) // we remove voters that are no longer needed (i.e StateProofNextRound is larger ) or older than the recover period - if stateProofRound < hdr.StateProofTracking[protocol.StateProofBasic].StateProofNextRound || - stateProofRound <= oldestRoundOnRecoveryPeriod { + if stateProofRound < lowestStateProofRound { delete(vt.votersForRoundCache, r) } } diff --git a/ledger/voters_test.go b/ledger/voters_test.go index a257872ef..9c65e7366 100644 --- a/ledger/voters_test.go +++ b/ledger/voters_test.go @@ -110,7 +110,6 @@ func TestLimitVoterTracker(t *testing.T) { intervalForTest := config.Consensus[protocol.ConsensusFuture].StateProofInterval recoveryIntervalForTests := config.Consensus[protocol.ConsensusFuture].StateProofMaxRecoveryIntervals - numOfIntervals := recoveryIntervalForTests lookbackForTest := config.Consensus[protocol.ConsensusFuture].StateProofVotersLookback accts := []map[basics.Address]basics.AccountData{ledgertesting.RandomAccounts(20, true)} @@ -134,35 +133,60 @@ func TestLimitVoterTracker(t *testing.T) { defer ao.close() i := uint64(1) - // adding blocks to the voterstracker (in order to pass the numOfIntervals*stateproofInterval we add 1) - for ; i < (numOfIntervals*intervalForTest)+1; i++ { + + // since the first state proof is expected to happen on stateproofInterval*2 we would start give-up on state proofs + // after intervalForTest*(recoveryIntervalForTests+3) + + // should not give up on any state proof + for ; i < intervalForTest*(recoveryIntervalForTests+2); i++ { block := randomBlock(basics.Round(i)) block.block.CurrentProtocol = protocol.ConsensusFuture addBlockToAccountsUpdate(block.block, ao) } - a.Equal(recoveryIntervalForTests, uint64(len(ao.voters.votersForRoundCache))) - a.Equal(basics.Round(((i/intervalForTest)-recoveryIntervalForTests+1)*intervalForTest-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) + // the votersForRoundCache should contains recoveryIntervalForTests+2 elements: + // recoveryIntervalForTests - since this is the recovery interval + // + 1 - since votersForRoundCache would contain the votersForRound for the next state proof to come + // + 1 - in order to confirm recoveryIntervalForTests number of state proofs we need recoveryIntervalForTests + 1 headers (for the commitment) + a.Equal(recoveryIntervalForTests+2, uint64(len(ao.voters.votersForRoundCache))) + a.Equal(basics.Round(config.Consensus[protocol.ConsensusFuture].StateProofInterval-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) - // we add numOfIntervals*intervalForTest more blocks. the voter should have only recoveryIntervalForTests number of elements - for ; i < 2*(numOfIntervals*intervalForTest)+1; i++ { + // after adding the round intervalForTest*(recoveryIntervalForTests+3)+1 we expect the voter tracker to remove voters + for ; i < intervalForTest*(recoveryIntervalForTests+3)+1; i++ { block := randomBlock(basics.Round(i)) block.block.CurrentProtocol = protocol.ConsensusFuture addBlockToAccountsUpdate(block.block, ao) } - a.Equal(recoveryIntervalForTests+1, uint64(len(ao.voters.votersForRoundCache))) - a.Equal(basics.Round(((i/intervalForTest)-recoveryIntervalForTests)*intervalForTest-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) + a.Equal(recoveryIntervalForTests+2, uint64(len(ao.voters.votersForRoundCache))) + a.Equal(basics.Round(config.Consensus[protocol.ConsensusFuture].StateProofInterval*2-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) - // we add numOfIntervals*intervalForTest more blocks. the voter should have only recoveryIntervalForTests number of elements - for ; i < 3*(numOfIntervals*intervalForTest)+1; i++ { + // after adding the round intervalForTest*(recoveryIntervalForTests+3)+1 we expect the voter tracker to remove voters + for ; i < intervalForTest*(recoveryIntervalForTests+4)+1; i++ { block := randomBlock(basics.Round(i)) block.block.CurrentProtocol = protocol.ConsensusFuture addBlockToAccountsUpdate(block.block, ao) } + a.Equal(recoveryIntervalForTests+2, uint64(len(ao.voters.votersForRoundCache))) + a.Equal(basics.Round(config.Consensus[protocol.ConsensusFuture].StateProofInterval*3-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) - a.Equal(recoveryIntervalForTests+1, uint64(len(ao.voters.votersForRoundCache))) - a.Equal(basics.Round(((i/intervalForTest)-recoveryIntervalForTests)*intervalForTest-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) + // if the last round of the intervalForTest has not been added to the ledger the votersTracker would + // retain one more element + for ; i < intervalForTest*(recoveryIntervalForTests+5); i++ { + block := randomBlock(basics.Round(i)) + block.block.CurrentProtocol = protocol.ConsensusFuture + addBlockToAccountsUpdate(block.block, ao) + } + a.Equal(recoveryIntervalForTests+3, uint64(len(ao.voters.votersForRoundCache))) + a.Equal(basics.Round(config.Consensus[protocol.ConsensusFuture].StateProofInterval*3-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) + + for ; i < intervalForTest*(recoveryIntervalForTests+5)+1; i++ { + block := randomBlock(basics.Round(i)) + block.block.CurrentProtocol = protocol.ConsensusFuture + addBlockToAccountsUpdate(block.block, ao) + } + a.Equal(recoveryIntervalForTests+2, uint64(len(ao.voters.votersForRoundCache))) + a.Equal(basics.Round(config.Consensus[protocol.ConsensusFuture].StateProofInterval*4-lookbackForTest), ao.voters.lowestRound(basics.Round(i))) } func TestTopNAccountsThatHaveNoMssKeys(t *testing.T) { diff --git a/stateproof/abstractions.go b/stateproof/abstractions.go index 8e77f2860..825b5090e 100644 --- a/stateproof/abstractions.go +++ b/stateproof/abstractions.go @@ -53,7 +53,7 @@ type Network interface { // Accounts captures the aspects of the AccountManager that are used by // this package. type Accounts interface { - StateProofKeys(basics.Round) []account.StateProofRecordForRound + StateProofKeys(basics.Round) []account.StateProofSecretsForRound DeleteStateProofKey(id account.ParticipationID, round basics.Round) error } diff --git a/stateproof/builder.go b/stateproof/builder.go index 02d9f85ee..fd800ebaf 100644 --- a/stateproof/builder.go +++ b/stateproof/builder.go @@ -21,6 +21,7 @@ import ( "database/sql" "encoding/binary" "fmt" + "sort" "github.com/algorand/go-algorand/config" "github.com/algorand/go-algorand/crypto/stateproof" @@ -65,7 +66,7 @@ func (spw *Worker) makeBuilderForRound(rnd basics.Round) (builder, error) { return builder{}, err } - provenWeight, err := verify.GetProvenWeight(votersHdr, hdr) + provenWeight, err := verify.GetProvenWeight(&votersHdr, &hdr) if err != nil { return builder{}, err } @@ -174,8 +175,7 @@ func (spw *Worker) handleSig(sfa sigFromAddr, sender network.Peer) (network.Forw latest := spw.ledger.Latest() latestHdr, err := spw.ledger.BlockHdr(latest) if err != nil { - // The latest block in the ledger should never disappear, so this should never happen - return network.Disconnect, err + return network.Ignore, err } if sfa.Round < latestHdr.StateProofTracking[protocol.StateProofBasic].StateProofNextRound { @@ -273,8 +273,8 @@ func (spw *Worker) builder(latest basics.Round) { continue } - spw.deleteOldSigs(hdr) - spw.deleteOldBuilders(hdr) + spw.deleteOldSigs(&hdr) + spw.deleteOldBuilders(&hdr) // Broadcast signatures based on the previous block(s) that // were agreed upon. This ensures that, if we send a signature @@ -357,29 +357,8 @@ func (spw *Worker) broadcastSigs(brnd basics.Round, proto config.ConsensusParams } } -func lowestRoundToRemove(currentHdr bookkeeping.BlockHeader) basics.Round { - proto := config.Consensus[currentHdr.CurrentProtocol] - nextStateProofRnd := currentHdr.StateProofTracking[protocol.StateProofBasic].StateProofNextRound - if proto.StateProofInterval == 0 { - return nextStateProofRnd - } - - recentRoundOnRecoveryPeriod := basics.Round(uint64(currentHdr.Round) - uint64(currentHdr.Round)%proto.StateProofInterval) - oldestRoundOnRecoveryPeriod := recentRoundOnRecoveryPeriod.SubSaturate(basics.Round(proto.StateProofInterval * proto.StateProofMaxRecoveryIntervals)) - // we add +1 to this number since we want exactly StateProofMaxRecoveryIntervals elements in the history - oldestRoundOnRecoveryPeriod++ - - var oldestRoundToRemove basics.Round - if oldestRoundOnRecoveryPeriod > nextStateProofRnd { - oldestRoundToRemove = oldestRoundOnRecoveryPeriod - } else { - oldestRoundToRemove = nextStateProofRnd - } - return oldestRoundToRemove -} - -func (spw *Worker) deleteOldSigs(currentHdr bookkeeping.BlockHeader) { - oldestRoundToRemove := lowestRoundToRemove(currentHdr) +func (spw *Worker) deleteOldSigs(currentHdr *bookkeeping.BlockHeader) { + oldestRoundToRemove := GetOldestExpectedStateProof(currentHdr) err := spw.db.Atomic(func(ctx context.Context, tx *sql.Tx) error { return deletePendingSigsBeforeRound(tx, oldestRoundToRemove) @@ -389,8 +368,8 @@ func (spw *Worker) deleteOldSigs(currentHdr bookkeeping.BlockHeader) { } } -func (spw *Worker) deleteOldBuilders(currentHdr bookkeeping.BlockHeader) { - oldestRoundToRemove := lowestRoundToRemove(currentHdr) +func (spw *Worker) deleteOldBuilders(currentHdr *bookkeeping.BlockHeader) { + oldestRoundToRemove := GetOldestExpectedStateProof(currentHdr) spw.mu.Lock() defer spw.mu.Unlock() @@ -406,9 +385,16 @@ func (spw *Worker) tryBroadcast() { spw.mu.Lock() defer spw.mu.Unlock() - for rnd, b := range spw.builders { + sortedRounds := make([]basics.Round, 0, len(spw.builders)) + for rnd := range spw.builders { + sortedRounds = append(sortedRounds, rnd) + } + sort.Slice(sortedRounds, func(i, j int) bool { return sortedRounds[i] < sortedRounds[j] }) + + for _, rnd := range sortedRounds { // Iterate over the builders in a sequential manner + b := spw.builders[rnd] firstValid := spw.ledger.Latest() - acceptableWeight := verify.AcceptableStateProofWeight(b.votersHdr, firstValid, logging.Base()) + acceptableWeight := verify.AcceptableStateProofWeight(&b.votersHdr, firstValid, logging.Base()) if b.SignedWeight() < acceptableWeight { // Haven't signed enough to build the state proof at this time.. continue @@ -438,6 +424,9 @@ func (spw *Worker) tryBroadcast() { err = spw.txnSender.BroadcastInternalSignedTxGroup([]transactions.SignedTxn{stxn}) if err != nil { spw.log.Warnf("spw.tryBroadcast: broadcasting state proof txn for %d: %v", rnd, err) + // if this StateProofTxn was rejected, the next one would be rejected as well since state proof should be added in + // a sequential order + break } } } diff --git a/stateproof/recovery.go b/stateproof/recovery.go new file mode 100644 index 000000000..5902a603e --- /dev/null +++ b/stateproof/recovery.go @@ -0,0 +1,42 @@ +// Copyright (C) 2019-2022 Algorand, Inc. +// This file is part of go-algorand +// +// go-algorand is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// go-algorand is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with go-algorand. If not, see <https://www.gnu.org/licenses/>. + +package stateproof + +import ( + "github.com/algorand/go-algorand/config" + "github.com/algorand/go-algorand/data/basics" + "github.com/algorand/go-algorand/data/bookkeeping" + "github.com/algorand/go-algorand/protocol" +) + +// GetOldestExpectedStateProof returns the lowest round for which the node should create a state proof. +func GetOldestExpectedStateProof(latestHeader *bookkeeping.BlockHeader) basics.Round { + proto := config.Consensus[latestHeader.CurrentProtocol] + if proto.StateProofInterval == 0 { + return 0 + } + + recentRoundOnRecoveryPeriod := basics.Round(uint64(latestHeader.Round) - uint64(latestHeader.Round)%proto.StateProofInterval) + oldestRoundOnRecoveryPeriod := recentRoundOnRecoveryPeriod.SubSaturate(basics.Round(proto.StateProofInterval * (proto.StateProofMaxRecoveryIntervals))) + + nextStateproofRound := latestHeader.StateProofTracking[protocol.StateProofBasic].StateProofNextRound + + if nextStateproofRound > oldestRoundOnRecoveryPeriod { + return nextStateproofRound + } + return oldestRoundOnRecoveryPeriod +} diff --git a/stateproof/stateproofMessageGenerator_test.go b/stateproof/stateproofMessageGenerator_test.go index 9c79df8b8..ab3399e51 100644 --- a/stateproof/stateproofMessageGenerator_test.go +++ b/stateproof/stateproofMessageGenerator_test.go @@ -45,7 +45,7 @@ type workerForStateProofMessageTests struct { w *testWorkerStubs } -func (s *workerForStateProofMessageTests) StateProofKeys(round basics.Round) []account.StateProofRecordForRound { +func (s *workerForStateProofMessageTests) StateProofKeys(round basics.Round) []account.StateProofSecretsForRound { return s.w.StateProofKeys(round) } diff --git a/stateproof/verify/stateproof.go b/stateproof/verify/stateproof.go index 1869aa74a..66c6f09d4 100644 --- a/stateproof/verify/stateproof.go +++ b/stateproof/verify/stateproof.go @@ -30,13 +30,12 @@ import ( ) var ( - errStateProofCrypto = errors.New("state proof crypto error") - errStateProofParamCreation = errors.New("state proof param creation error") - errStateProofNotEnabled = errors.New("state proofs are not enabled") - errNotAtRightMultiple = errors.New("state proof is not in a valid round multiple") - errInvalidVotersRound = errors.New("invalid voters round") - errExpectedDifferentStateProofRound = errors.New("expected different state proof round") - errInsufficientWeight = errors.New("insufficient state proof weight") + errStateProofCrypto = errors.New("state proof crypto error") + errStateProofParamCreation = errors.New("state proof param creation error") + errStateProofNotEnabled = errors.New("state proofs are not enabled") + errNotAtRightMultiple = errors.New("state proof is not in a valid round multiple") + errInvalidVotersRound = errors.New("invalid voters round") + errInsufficientWeight = errors.New("insufficient state proof weight") ) // AcceptableStateProofWeight computes the acceptable signed weight @@ -47,7 +46,7 @@ var ( // (votersHdr.Round(), votersHdr.Round()+StateProofInterval]. // // logger must not be nil; use at least logging.Base() -func AcceptableStateProofWeight(votersHdr bookkeeping.BlockHeader, firstValid basics.Round, logger logging.Logger) uint64 { +func AcceptableStateProofWeight(votersHdr *bookkeeping.BlockHeader, firstValid basics.Round, logger logging.Logger) uint64 { proto := config.Consensus[votersHdr.CurrentProtocol] latestRoundInProof := votersHdr.Round + basics.Round(proto.StateProofInterval) total := votersHdr.StateProofTracking[protocol.StateProofBasic].StateProofOnlineTotalWeight @@ -106,7 +105,7 @@ func AcceptableStateProofWeight(votersHdr bookkeeping.BlockHeader, firstValid ba // GetProvenWeight computes the parameters for building or verifying // a state proof for the interval (votersHdr, latestRoundInProofHdr], using voters from block votersHdr. -func GetProvenWeight(votersHdr bookkeeping.BlockHeader, latestRoundInProofHdr bookkeeping.BlockHeader) (uint64, error) { +func GetProvenWeight(votersHdr *bookkeeping.BlockHeader, latestRoundInProofHdr *bookkeeping.BlockHeader) (uint64, error) { proto := config.Consensus[votersHdr.CurrentProtocol] if proto.StateProofInterval == 0 { @@ -137,7 +136,7 @@ func GetProvenWeight(votersHdr bookkeeping.BlockHeader, latestRoundInProofHdr bo } // ValidateStateProof checks that a state proof is valid. -func ValidateStateProof(latestRoundInIntervalHdr *bookkeeping.BlockHeader, stateProof *stateproof.StateProof, votersHdr *bookkeeping.BlockHeader, nextStateProofRnd basics.Round, atRound basics.Round, msg *stateproofmsg.Message) error { +func ValidateStateProof(latestRoundInIntervalHdr *bookkeeping.BlockHeader, stateProof *stateproof.StateProof, votersHdr *bookkeeping.BlockHeader, atRound basics.Round, msg *stateproofmsg.Message) error { proto := config.Consensus[latestRoundInIntervalHdr.CurrentProtocol] if proto.StateProofInterval == 0 { @@ -154,18 +153,13 @@ func ValidateStateProof(latestRoundInIntervalHdr *bookkeeping.BlockHeader, state latestRoundInIntervalHdr.Round, votersRound, votersHdr.Round, errInvalidVotersRound) } - if nextStateProofRnd == 0 || nextStateProofRnd != latestRoundInIntervalHdr.Round { - return fmt.Errorf("expecting state proof for %d, but new state proof is for %d (voters %d):%w", - nextStateProofRnd, latestRoundInIntervalHdr.Round, votersRound, errExpectedDifferentStateProofRound) - } - - acceptableWeight := AcceptableStateProofWeight(*votersHdr, atRound, logging.Base()) + acceptableWeight := AcceptableStateProofWeight(votersHdr, atRound, logging.Base()) if stateProof.SignedWeight < acceptableWeight { return fmt.Errorf("insufficient weight at round %d: %d < %d: %w", atRound, stateProof.SignedWeight, acceptableWeight, errInsufficientWeight) } - provenWeight, err := GetProvenWeight(*votersHdr, *latestRoundInIntervalHdr) + provenWeight, err := GetProvenWeight(votersHdr, latestRoundInIntervalHdr) if err != nil { return fmt.Errorf("%v: %w", err, errStateProofParamCreation) } diff --git a/stateproof/verify/stateproof_test.go b/stateproof/verify/stateproof_test.go index b3dfc02c7..38a76c2b8 100644 --- a/stateproof/verify/stateproof_test.go +++ b/stateproof/verify/stateproof_test.go @@ -37,13 +37,11 @@ func TestValidateStateProof(t *testing.T) { spHdr := &bookkeeping.BlockHeader{} sp := &stateproof.StateProof{} votersHdr := &bookkeeping.BlockHeader{} - var nextSPRnd basics.Round var atRound basics.Round msg := &stateproofmsg.Message{BlockHeadersCommitment: []byte("this is an arbitrary message")} // will definitely fail with nothing set up - err := ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - t.Log(err) + err := ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) require.ErrorIs(t, err, errStateProofNotEnabled) spHdr.CurrentProtocol = "TestValidateStateProof" @@ -54,33 +52,20 @@ func TestValidateStateProof(t *testing.T) { proto.StateProofWeightThreshold = (1 << 32) * 30 / 100 config.Consensus[spHdr.CurrentProtocol] = proto - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - // still err, but a different err case to cover - t.Log(err) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) require.ErrorIs(t, err, errNotAtRightMultiple) spHdr.Round = 4 votersHdr.Round = 4 - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - // still err, but a different err case to cover - t.Log(err) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) require.ErrorIs(t, err, errInvalidVotersRound) votersHdr.Round = 2 - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - // still err, but a different err case to cover - t.Log(err) - require.ErrorIs(t, err, errExpectedDifferentStateProofRound) - - nextSPRnd = 4 - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - // still err, but a different err case to cover - t.Log(err) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) require.ErrorIs(t, err, errStateProofParamCreation) votersHdr.CurrentProtocol = spHdr.CurrentProtocol - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - t.Log(err) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) // since proven weight is zero, we cann't create the verifier require.ErrorIs(t, err, stateproof.ErrIllegalInputForLnApprox) @@ -88,28 +73,23 @@ func TestValidateStateProof(t *testing.T) { cc := votersHdr.StateProofTracking[protocol.StateProofBasic] cc.StateProofOnlineTotalWeight.Raw = 100 votersHdr.StateProofTracking[protocol.StateProofBasic] = cc - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - // still err, but a different err case to cover - t.Log(err) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) require.ErrorIs(t, err, errInsufficientWeight) // Require 100% of the weight to be signed in order to accept stateproof before interval/2 rounds has passed from the latest round attested (optimal case) sp.SignedWeight = 99 // suboptimal signed weight - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - t.Log(err) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) require.ErrorIs(t, err, errInsufficientWeight) latestRoundInProof := votersHdr.Round + basics.Round(proto.StateProofInterval) atRound = latestRoundInProof + basics.Round(proto.StateProofInterval/2) - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) - t.Log(err) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) require.ErrorIs(t, err, errInsufficientWeight) // This suboptimal signed weight should be enough for this round atRound++ - err = ValidateStateProof(spHdr, sp, votersHdr, nextSPRnd, atRound, msg) + err = ValidateStateProof(spHdr, sp, votersHdr, atRound, msg) // still err, but a different err case to cover - t.Log(err) require.ErrorIs(t, err, errStateProofCrypto) // Above cases leave validateStateProof() with 100% coverage. @@ -127,25 +107,25 @@ func TestAcceptableStateProofWeight(t *testing.T) { proto := config.Consensus[votersHdr.CurrentProtocol] proto.StateProofInterval = 2 config.Consensus[votersHdr.CurrentProtocol] = proto - out := AcceptableStateProofWeight(votersHdr, firstValid, logger) + out := AcceptableStateProofWeight(&votersHdr, firstValid, logger) require.Equal(t, uint64(0), out) votersHdr.StateProofTracking = make(map[protocol.StateProofType]bookkeeping.StateProofTrackingData) cc := votersHdr.StateProofTracking[protocol.StateProofBasic] cc.StateProofOnlineTotalWeight.Raw = 100 votersHdr.StateProofTracking[protocol.StateProofBasic] = cc - out = AcceptableStateProofWeight(votersHdr, firstValid, logger) + out = AcceptableStateProofWeight(&votersHdr, firstValid, logger) require.Equal(t, uint64(100), out) // this should exercise the second return case firstValid = basics.Round(3) - out = AcceptableStateProofWeight(votersHdr, firstValid, logger) + out = AcceptableStateProofWeight(&votersHdr, firstValid, logger) require.Equal(t, uint64(100), out) firstValid = basics.Round(6) proto.StateProofWeightThreshold = 999999999 config.Consensus[votersHdr.CurrentProtocol] = proto - out = AcceptableStateProofWeight(votersHdr, firstValid, logger) + out = AcceptableStateProofWeight(&votersHdr, firstValid, logger) require.Equal(t, uint64(0x17), out) proto.StateProofInterval = 10000 @@ -156,7 +136,7 @@ func TestAcceptableStateProofWeight(t *testing.T) { votersHdr.StateProofTracking[protocol.StateProofBasic] = cc proto.StateProofWeightThreshold = 0x7fffffff config.Consensus[votersHdr.CurrentProtocol] = proto - out = AcceptableStateProofWeight(votersHdr, firstValid, logger) + out = AcceptableStateProofWeight(&votersHdr, firstValid, logger) require.Equal(t, uint64(0x4cd35a85213a92a2), out) // Covers everything except "overflow that shouldn't happen" branches @@ -168,7 +148,7 @@ func TestStateProofParams(t *testing.T) { var votersHdr bookkeeping.BlockHeader var hdr bookkeeping.BlockHeader - _, err := GetProvenWeight(votersHdr, hdr) + _, err := GetProvenWeight(&votersHdr, &hdr) require.Error(t, err) // not enabled votersHdr.CurrentProtocol = "TestStateProofParams" @@ -176,12 +156,12 @@ func TestStateProofParams(t *testing.T) { proto.StateProofInterval = 2 config.Consensus[votersHdr.CurrentProtocol] = proto votersHdr.Round = 1 - _, err = GetProvenWeight(votersHdr, hdr) + _, err = GetProvenWeight(&votersHdr, &hdr) require.Error(t, err) // wrong round votersHdr.Round = 2 hdr.Round = 3 - _, err = GetProvenWeight(votersHdr, hdr) + _, err = GetProvenWeight(&votersHdr, &hdr) require.Error(t, err) // wrong round // Covers all cases except overflow diff --git a/stateproof/worker_test.go b/stateproof/worker_test.go index d8afb2807..0034b45b0 100644 --- a/stateproof/worker_test.go +++ b/stateproof/worker_test.go @@ -19,6 +19,7 @@ package stateproof import ( "context" "database/sql" + "encoding/binary" "fmt" "io/ioutil" "strings" @@ -113,7 +114,7 @@ func (s *testWorkerStubs) addBlock(spNextRound basics.Round) { } } -func (s *testWorkerStubs) StateProofKeys(rnd basics.Round) (out []account.StateProofRecordForRound) { +func (s *testWorkerStubs) StateProofKeys(rnd basics.Round) (out []account.StateProofSecretsForRound) { for _, part := range s.keys { partRecord := account.ParticipationRecord{ ParticipationID: part.ID(), @@ -130,7 +131,7 @@ func (s *testWorkerStubs) StateProofKeys(rnd basics.Round) (out []account.StateP Voting: part.Voting, } signerInRound := part.StateProofSecrets.GetSigner(uint64(rnd)) - partRecordForRound := account.StateProofRecordForRound{ + partRecordForRound := account.StateProofSecretsForRound{ ParticipationRecord: partRecord, StateProofSecrets: signerInRound, } @@ -654,7 +655,6 @@ func TestWorkerBuildersRecoveryLimit(t *testing.T) { a := require.New(t) proto := config.Consensus[protocol.ConsensusFuture] - expectedStateProofs := proto.StateProofMaxRecoveryIntervals + 1 var keys []account.Participation for i := 0; i < 10; i++ { var parent basics.Address @@ -671,29 +671,55 @@ func TestWorkerBuildersRecoveryLimit(t *testing.T) { s.advanceLatest(proto.StateProofInterval + proto.StateProofInterval/2) - for iter := uint64(0); iter < expectedStateProofs; iter++ { + for iter := uint64(0); iter < proto.StateProofMaxRecoveryIntervals+1; iter++ { s.advanceLatest(proto.StateProofInterval) tx := <-s.txmsg a.Equal(tx.Txn.Type, protocol.StateProofTx) } + // since this test involves go routine, we would like to make sure that when + // we sample the builder it already processed our current round. + // in order to that, we wait for singer and the builder to wait. + // then we push one more round so the builder could process it (since the builder might skip rounds) err := waitForBuilderAndSignerToWaitOnRound(s) a.NoError(err) - s.mu.Lock() s.addBlock(basics.Round(proto.StateProofInterval * 2)) s.mu.Unlock() - err = waitForBuilderAndSignerToWaitOnRound(s) a.NoError(err) - a.Equal(proto.StateProofMaxRecoveryIntervals, uint64(len(w.builders))) + + // should not give up on rounds + a.Equal(proto.StateProofMaxRecoveryIntervals+1, uint64(len(w.builders))) var roundSigs map[basics.Round][]pendingSig err = w.db.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { roundSigs, err = getPendingSigs(tx) return }) - a.Equal(proto.StateProofMaxRecoveryIntervals, uint64(len(roundSigs))) + a.Equal(proto.StateProofMaxRecoveryIntervals+1, uint64(len(roundSigs))) + + s.advanceLatest(proto.StateProofInterval) + tx := <-s.txmsg + a.Equal(tx.Txn.Type, protocol.StateProofTx) + + err = waitForBuilderAndSignerToWaitOnRound(s) + a.NoError(err) + s.mu.Lock() + s.addBlock(basics.Round(proto.StateProofInterval * 2)) + s.mu.Unlock() + err = waitForBuilderAndSignerToWaitOnRound(s) + a.NoError(err) + + // should not give up on rounds + a.Equal(proto.StateProofMaxRecoveryIntervals+1, uint64(len(w.builders))) + + roundSigs = make(map[basics.Round][]pendingSig) + err = w.db.Atomic(func(ctx context.Context, tx *sql.Tx) (err error) { + roundSigs, err = getPendingSigs(tx) + return + }) + a.Equal(proto.StateProofMaxRecoveryIntervals+1, uint64(len(roundSigs))) } func waitForBuilderAndSignerToWaitOnRound(s *testWorkerStubs) error { @@ -728,18 +754,23 @@ const ( sigAlternateOrigin ) -// getSignaturesInDatabase sets up the db with signatures +// getSignaturesInDatabase sets up the db with signatures. This function supports creating up to StateProofInterval/2 address. func getSignaturesInDatabase(t *testing.T, numAddresses int, sigFrom sigOrigin) ( signatureBcasted map[basics.Address]int, fromThisNode map[basics.Address]bool, tns *testWorkerStubs, spw *Worker) { + // Some tests rely on having only one signature being broadcast at a single round. + // for that we need to make sure that addresses won't fall into the same broadcast round. + // For that same reason we can't have more than StateProofInterval / 2 address + require.LessOrEqual(t, uint64(numAddresses), config.Consensus[protocol.ConsensusFuture].StateProofInterval/2) + // Prepare the addresses and the keys signatureBcasted = make(map[basics.Address]int) fromThisNode = make(map[basics.Address]bool) var keys []account.Participation for i := 0; i < numAddresses; i++ { var parent basics.Address - crypto.RandBytes(parent[:]) + binary.LittleEndian.PutUint64(parent[:], uint64(i)) p := newPartKey(t, parent) defer p.Close() keys = append(keys, p.Participation) |