summaryrefslogtreecommitdiff
path: root/crypto/merkletrie/trie.go
blob: 6bb2a51851b7b662493dc935245bd416f57b2c5e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
// Copyright (C) 2019-2024 Algorand, Inc.
// This file is part of go-algorand
//
// go-algorand is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// go-algorand is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with go-algorand.  If not, see <https://www.gnu.org/licenses/>.

package merkletrie

import (
	"encoding/binary"
	"errors"

	"github.com/algorand/go-algorand/crypto"
)

const (
	// merkleTreeVersion is the version tag of the encoded trie root page. serialize writes it as the
	// first field of the root page header, and deserialize rejects any root page whose version differs.
	// Changing it gives us an upgrade path should the trie encoding ever need to change.
	merkleTreeVersion = uint64(0x1000000010000000)
	// nodePageVersion is the version tag of the encoded node. Changing it gives us an upgrade path
	// should the node encoding ever need to change.
	nodePageVersion = uint64(0x1000000010000000)
)

// ErrRootPageDecodingFailure is returned when decoding the root page fails: either one of the
// header fields could not be read, or the stored version does not match merkleTreeVersion.
var ErrRootPageDecodingFailure = errors.New("error encountered while decoding root page")

// ErrMismatchingElementLength is returned by Add and Delete when the length of the given element
// differs from the length of the elements already stored in the trie.
var ErrMismatchingElementLength = errors.New("mismatching element length")

// ErrMismatchingPageSize is returned by MakeTrie when the committer holds an existing trie whose
// persisted page size differs from the NodesCountPerPage of the provided memory configuration.
var ErrMismatchingPageSize = errors.New("mismatching page size")

// ErrUnableToEvictPendingCommits is returned by Evict when it is called with commit=false while
// the trie has modifications that were not committed yet.
var ErrUnableToEvictPendingCommits = errors.New("unable to evict as pending commits available")

// MemoryConfig defines the memory configuration of a Trie object. It is handed to the trie's
// cache when the trie is created by MakeTrie.
type MemoryConfig struct {
	// NodesCountPerPage defines how many nodes each page would contain. When an existing trie is
	// loaded, MakeTrie requires this value to match the page size persisted in the root page.
	NodesCountPerPage int64
	// CachedNodesCount defines the number of nodes we want to retain in memory between consecutive Evict calls.
	CachedNodesCount int
	// PageFillFactor defines the desired fill ratio of a created page.
	PageFillFactor float32
	// MaxChildrenPagesThreshold defines the maximum number of different pages that would be used for a single node's children.
	// It is evaluated during Commit, for all the updated nodes.
	MaxChildrenPagesThreshold uint64
}

// Trie is a merkle trie intended to efficiently calculate the merkle root of
// unordered elements.
//
// Fields:
//   - root is the identifier of the root node; it equals storedNodeIdentifierNull
//     while the trie is empty.
//   - nextNodeID is persisted in the root page header and restored on load
//     (presumably the next identifier to be allocated; it is maintained outside
//     of this file's visible code).
//   - lastCommittedNodeID is set to nextNodeID on every successful cache commit
//     and when a root page is deserialized.
//   - cache holds the in-memory nodes and the committer used for persistence.
//   - elementLength is the byte length shared by all elements; it is fixed by
//     the first Add and reset to zero when the last element is deleted.
type Trie struct {
	root                storedNodeIdentifier
	nextNodeID          storedNodeIdentifier
	lastCommittedNodeID storedNodeIdentifier
	cache               merkleTrieCache
	elementLength       int
}

// Stats is a helper structure for reporting underlying statistics about the
// trie. It is filled in by GetStats, which walks the trie starting at the root
// node. The field names suggest node/leaf counts, maximal depth and total size;
// the exact semantics are defined by the node traversal code (not shown here).
type Stats struct {
	NodesCount uint
	LeafCount  uint
	Depth      int
	Size       int
}

// MakeTrie creates a merkle trie backed by the given committer.
//
// When committer is nil, a new InMemoryCommitter is used and the trie starts
// out empty. Otherwise the root page (stored under storedNodeIdentifierNull)
// is loaded from the committer; if such a page exists, the trie header is
// restored from it and the persisted page size must be equal to
// memoryConfig.NodesCountPerPage.
//
// The returned error is either the error reported by the committer's LoadPage,
// the root page decoding error, or ErrMismatchingPageSize.
func MakeTrie(committer Committer, memoryConfig MemoryConfig) (*Trie, error) {
	trie := &Trie{
		root:                storedNodeIdentifierNull,
		cache:               merkleTrieCache{},
		nextNodeID:          storedNodeIdentifierBase,
		lastCommittedNodeID: storedNodeIdentifierBase,
	}
	if committer != nil {
		rootPage, err := committer.LoadPage(storedNodeIdentifierNull)
		if err != nil {
			return nil, err
		}
		// a nil root page means that the committer has no persisted trie yet.
		if rootPage != nil {
			storedPageSize, err := trie.deserialize(rootPage)
			if err != nil {
				return nil, err
			}
			if storedPageSize != memoryConfig.NodesCountPerPage {
				return nil, ErrMismatchingPageSize
			}
		}
	} else {
		committer = &InMemoryCommitter{}
	}
	trie.cache.initialize(trie, committer, memoryConfig)
	return trie, nil
}

// SetCommitter sets the provided committer as the current committer of the
// trie's cache. The committer is stored as-is: unlike MakeTrie, no root page
// is loaded from it, no page size validation takes place, and a nil committer
// is not replaced by an in-memory one.
func (mt *Trie) SetCommitter(committer Committer) {
	mt.cache.committer = committer
}

// RootHash returns the root hash of all the elements in the trie.
//
// An empty trie yields an empty crypto.Digest. If the cache has pending
// modifications, they are committed first, and a commit failure is returned
// to the caller. The resulting digest is the hash of the root node's hash
// prefixed with a single tag byte: 0 when the root is a leaf, 1 otherwise.
func (mt *Trie) RootHash() (crypto.Digest, error) {
	if mt.root == storedNodeIdentifierNull {
		return crypto.Digest{}, nil
	}
	if mt.cache.modified {
		_, err := mt.Commit()
		if err != nil {
			return crypto.Digest{}, err
		}
	}
	rootNode, err := mt.cache.getNode(mt.root)
	if err != nil {
		return crypto.Digest{}, err
	}

	// the tag byte tells apart a leaf root from an inner-node root.
	tag := byte(1)
	if rootNode.leaf() {
		tag = 0
	}
	preimage := make([]byte, 0, 1+len(rootNode.hash))
	preimage = append(preimage, tag)
	preimage = append(preimage, rootNode.hash...)
	return crypto.Hash(preimage), nil
}

// Add adds the given hash to the trie.
//
// It returns true when the element was added, and false when the element
// already exists or when an error was encountered. The first element added to
// an empty trie determines the element length; any later element of a
// different length is rejected with ErrMismatchingElementLength.
func (mt *Trie) Add(d []byte) (bool, error) {
	if mt.root == storedNodeIdentifierNull {
		// the trie is empty: the new element becomes the root node.
		var rootNode *node
		mt.cache.beginTransaction()
		rootNode, mt.root = mt.cache.allocateNewNode()
		mt.cache.commitTransaction()
		// note that the slice is retained as-is rather than copied.
		rootNode.hash = d
		mt.elementLength = len(d)
		return true, nil
	}
	if len(d) != mt.elementLength {
		return false, ErrMismatchingElementLength
	}
	rootNode, err := mt.cache.getNode(mt.root)
	if err != nil {
		return false, err
	}
	exists, err := rootNode.find(&mt.cache, d)
	if err != nil {
		return false, err
	}
	if exists {
		return false, nil
	}

	mt.cache.beginTransaction()
	newRootID, err := rootNode.add(&mt.cache, d, make([]byte, 0, len(d)))
	if err != nil {
		mt.cache.rollbackTransaction()
		return false, err
	}
	// adding the element produced a new root; the previous one is dropped.
	mt.cache.deleteNode(mt.root)
	mt.root = newRootID
	mt.cache.commitTransaction()
	return true, nil
}

// Delete deletes the given hash from the trie, if such an element exists.
//
// It returns true when the element was removed, and false when the trie is
// empty, the element was not found, or an error was encountered. An element
// whose length differs from the trie's element length is rejected with
// ErrMismatchingElementLength.
func (mt *Trie) Delete(d []byte) (bool, error) {
	if mt.root == storedNodeIdentifierNull {
		return false, nil
	}
	if len(d) != mt.elementLength {
		return false, ErrMismatchingElementLength
	}
	rootNode, err := mt.cache.getNode(mt.root)
	if err != nil {
		return false, err
	}
	present, err := rootNode.find(&mt.cache, d)
	if err != nil {
		return false, err
	}
	if !present {
		return false, nil
	}

	mt.cache.beginTransaction()
	if rootNode.leaf() {
		// the root is the only element left: removing it empties the trie,
		// which also clears the element length constraint.
		mt.cache.deleteNode(mt.root)
		mt.root = storedNodeIdentifierNull
		mt.cache.commitTransaction()
		mt.elementLength = 0
		return true, nil
	}
	newRootID, err := rootNode.remove(&mt.cache, d, make([]byte, 0, len(d)))
	if err != nil {
		mt.cache.rollbackTransaction()
		return false, err
	}
	// removing the element produced a new root; the previous one is dropped.
	mt.cache.deleteNode(mt.root)
	mt.cache.commitTransaction()
	mt.root = newRootID
	return true, nil
}

// GetStats returns statistics about the merkle trie.
//
// An empty trie yields an empty Stats. Otherwise the statistics are collected
// by walking the trie from the root node, which is passed a depth of 1. If
// loading the root node fails, an empty Stats and the error are returned; if
// the walk itself fails, the statistics gathered so far are returned along
// with the error.
func (mt *Trie) GetStats() (stats Stats, err error) {
	if mt.root == storedNodeIdentifierNull {
		return Stats{}, nil
	}
	rootNode, loadErr := mt.cache.getNode(mt.root)
	if loadErr != nil {
		return Stats{}, loadErr
	}
	err = rootNode.stats(&mt.cache, &stats, 1)
	return stats, err
}

// Commit stores the existing trie using the committer.
//
// The cache is committed first; if that fails, the error is returned and the
// root page is left untouched. On success, lastCommittedNodeID is advanced to
// nextNodeID and the serialized trie header is stored as the root page (under
// storedNodeIdentifierNull). The cache commit statistics are returned along
// with any error reported while storing the root page.
func (mt *Trie) Commit() (stats CommitStats, err error) {
	if stats, err = mt.cache.commit(); err != nil {
		return stats, err
	}
	mt.lastCommittedNodeID = mt.nextNodeID
	rootPage := mt.serialize()
	err = mt.cache.committer.StorePage(storedNodeIdentifierNull, rootPage)
	return stats, err
}

// Evict removes elements from the cache that are no longer needed.
//
// If the trie has uncommitted modifications, they are committed first when
// commit is true; when commit is false the call fails with
// ErrUnableToEvictPendingCommits. On success, the value reported by the
// cache's evict is returned.
func (mt *Trie) Evict(commit bool) (int, error) {
	if mt.cache.modified {
		if !commit {
			return 0, ErrUnableToEvictPendingCommits
		}
		if _, err := mt.Commit(); err != nil {
			return 0, err
		}
	}
	return mt.cache.evict(), nil
}

// serialize encodes the trie header (the root page) as a sequence of five
// unsigned varints, in this order: the trie version, the root node identifier,
// the next node identifier, the element length and the number of nodes per
// page. The returned slice is trimmed to the encoded length.
func (mt *Trie) serialize() []byte {
	// allocate for the worst case, where every field takes the maximal varint length.
	header := make([]byte, 5*binary.MaxVarintLen64)
	fields := [...]uint64{
		merkleTreeVersion,
		uint64(mt.root),
		uint64(mt.nextNodeID),
		uint64(mt.elementLength),
		uint64(mt.cache.nodesPerPage),
	}
	offset := 0
	for _, field := range fields {
		offset += binary.PutUvarint(header[offset:], field)
	}
	return header[:offset]
}

// deserialize decodes the trie header (the root page) produced by serialize
// and restores the root identifier, the next node identifier (which is also
// taken as the last committed node identifier) and the element length.
//
// It returns the persisted number of nodes per page. If any of the fields
// cannot be decoded, or the stored version differs from merkleTreeVersion,
// ErrRootPageDecodingFailure is returned and the trie is left unmodified.
func (mt *Trie) deserialize(bytes []byte) (int64, error) {
	remaining := bytes
	// readField decodes the next unsigned varint and advances past it.
	readField := func() (uint64, bool) {
		value, length := binary.Uvarint(remaining)
		if length <= 0 {
			return 0, false
		}
		remaining = remaining[length:]
		return value, true
	}

	version, ok := readField()
	if !ok || version != merkleTreeVersion {
		return 0, ErrRootPageDecodingFailure
	}
	rootID, ok := readField()
	if !ok {
		return 0, ErrRootPageDecodingFailure
	}
	nextID, ok := readField()
	if !ok {
		return 0, ErrRootPageDecodingFailure
	}
	elementLength, ok := readField()
	if !ok {
		return 0, ErrRootPageDecodingFailure
	}
	nodesPerPage, ok := readField()
	if !ok {
		return 0, ErrRootPageDecodingFailure
	}

	// all the fields were decoded successfully; only now update the trie.
	mt.root = storedNodeIdentifier(rootID)
	mt.nextNodeID = storedNodeIdentifier(nextID)
	mt.lastCommittedNodeID = storedNodeIdentifier(nextID)
	mt.elementLength = int(elementLength)
	return int64(nodesPerPage), nil
}