summaryrefslogtreecommitdiff
path: root/data/basics/fields_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'data/basics/fields_test.go')
-rw-r--r--data/basics/fields_test.go201
1 files changed, 201 insertions, 0 deletions
diff --git a/data/basics/fields_test.go b/data/basics/fields_test.go
new file mode 100644
index 000000000..8027fa29a
--- /dev/null
+++ b/data/basics/fields_test.go
@@ -0,0 +1,201 @@
+// Copyright (C) 2019-2021 Algorand, Inc.
+// This file is part of go-algorand
+//
+// go-algorand is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// go-algorand is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with go-algorand. If not, see <https://www.gnu.org/licenses/>.
+
+package basics_test
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/algorand/go-algorand/data/basics"
+ "github.com/algorand/go-algorand/data/bookkeeping"
+ "github.com/algorand/go-algorand/test/partitiontest"
+ "github.com/stretchr/testify/require"
+)
+
+type typePath []string
+
+func (p typePath) addMapKey() typePath {
+ return append(p, "map_key")
+}
+
+func (p typePath) addValue() typePath {
+ return append(p, "value")
+}
+
+func (p typePath) addField(fieldName string) typePath {
+ return append(p, "field "+fieldName)
+}
+
+func (p typePath) validatePathFrom(t reflect.Type) error {
+ if len(p) == 0 {
+ // path is empty, so it's vacuously valid
+ return nil
+ }
+
+ value := p[0]
+ switch {
+ case value == "map_key":
+ return p[1:].validatePathFrom(t.Key())
+ case value == "value":
+ return p[1:].validatePathFrom(t.Elem())
+ case strings.HasPrefix(value, "field "):
+ fieldName := value[len("field "):]
+ fieldType, ok := t.FieldByName(fieldName)
+ if !ok {
+ return fmt.Errorf("Type '%s' does not have the field '%s'", t.Name(), fieldName)
+ }
+ return p[1:].validatePathFrom(fieldType.Type)
+ default:
+ return fmt.Errorf("Unexpected item in path: %s", value)
+ }
+}
+
+func (p typePath) Equals(other typePath) bool {
+ if len(p) != len(other) {
+ return false
+ }
+ for i := range p {
+ if p[i] != other[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func (p typePath) String() string {
+ return strings.Join(p, "->")
+}
+
+func checkReferencedTypes(seen map[reflect.Type]bool, path typePath, typeStack []reflect.Type, check func(path typePath, stack []reflect.Type) bool) {
+ currentType := typeStack[len(typeStack)-1]
+
+ if _, seenType := seen[currentType]; seenType {
+ return
+ }
+
+ if !check(path, typeStack) {
+ // if currentType is not ok, don't visit its children
+ return
+ }
+
+ // add currentType to seen set, to avoid infinite recursion if currentType references itself
+ seen[currentType] = true
+
+ // after currentType's children are visited, "forget" the type, so we can examine it again if needed
+ // if this didn't happen, only 1 error per invalid type would get reported
+ defer delete(seen, currentType)
+
+ switch currentType.Kind() {
+ case reflect.Map:
+ newPath := path.addMapKey()
+ newStack := append(typeStack, currentType.Key())
+ checkReferencedTypes(seen, newPath, newStack, check)
+ fallthrough
+ case reflect.Array, reflect.Slice, reflect.Ptr:
+ newPath := path.addValue()
+ newStack := append(typeStack, currentType.Elem())
+ checkReferencedTypes(seen, newPath, newStack, check)
+ case reflect.Struct:
+ for i := 0; i < currentType.NumField(); i++ {
+ field := currentType.Field(i)
+ newPath := path.addField(field.Name)
+ newStack := append(typeStack, field.Type)
+ checkReferencedTypes(seen, newPath, newStack, check)
+ }
+ }
+}
+
+func makeTypeCheckFunction(t *testing.T, exceptions []typePath, startType reflect.Type) func(path typePath, stack []reflect.Type) bool {
+ for _, exception := range exceptions {
+ err := exception.validatePathFrom(startType)
+ require.NoError(t, err)
+ }
+
+ return func(path typePath, stack []reflect.Type) bool {
+ currentType := stack[len(stack)-1]
+
+ for _, exception := range exceptions {
+ if path.Equals(exception) {
+ t.Logf("Skipping exception for path: %s", path.String())
+ return true
+ }
+ }
+
+ switch currentType.Kind() {
+ case reflect.String:
+ t.Errorf("Invalid string type referenced from %s. Use []byte instead. Full path: %s", startType.Name(), path.String())
+ return false
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.UnsafePointer:
+ // raise an error if one of these strange types is referenced too
+ t.Errorf("Invalid type %s referenced from %s. Full path: %s", currentType.Name(), startType.Name(), path.String())
+ return false
+ default:
+ return true
+ }
+ }
+}
+
+func TestBlockFields(t *testing.T) {
+ partitiontest.PartitionTest(t)
+
+ typeToCheck := reflect.TypeOf(bookkeeping.Block{})
+
+ // These exceptions are for pre-existing usages of string. Only add to this list if you really need to use string.
+ exceptions := []typePath{
+ typePath{}.addField("BlockHeader").addField("GenesisID"),
+ typePath{}.addField("BlockHeader").addField("UpgradeState").addField("CurrentProtocol"),
+ typePath{}.addField("BlockHeader").addField("UpgradeState").addField("NextProtocol"),
+ typePath{}.addField("BlockHeader").addField("UpgradeVote").addField("UpgradePropose"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("SignedTxn").addField("Txn").addField("Type"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("SignedTxn").addField("Txn").addField("Header").addField("GenesisID"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("SignedTxn").addField("Txn").addField("AssetConfigTxnFields").addField("AssetParams").addField("UnitName"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("SignedTxn").addField("Txn").addField("AssetConfigTxnFields").addField("AssetParams").addField("AssetName"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("SignedTxn").addField("Txn").addField("AssetConfigTxnFields").addField("AssetParams").addField("URL"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("ApplyData").addField("EvalDelta").addField("GlobalDelta").addMapKey(),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("ApplyData").addField("EvalDelta").addField("GlobalDelta").addValue().addField("Bytes"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("ApplyData").addField("EvalDelta").addField("LocalDeltas").addValue().addMapKey(),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("ApplyData").addField("EvalDelta").addField("LocalDeltas").addValue().addValue().addField("Bytes"),
+ typePath{}.addField("Payset").addValue().addField("SignedTxnWithAD").addField("ApplyData").addField("EvalDelta").addField("Logs").addValue(),
+ }
+
+ seen := make(map[reflect.Type]bool)
+
+ checkReferencedTypes(seen, nil, []reflect.Type{typeToCheck}, makeTypeCheckFunction(t, exceptions, typeToCheck))
+}
+
+func TestAccountDataFields(t *testing.T) {
+ partitiontest.PartitionTest(t)
+
+ typeToCheck := reflect.TypeOf(basics.AccountData{})
+
+ // These exceptions are for pre-existing usages of string. Only add to this list if you really need to use string.
+ exceptions := []typePath{
+ typePath{}.addField("AssetParams").addValue().addField("UnitName"),
+ typePath{}.addField("AssetParams").addValue().addField("AssetName"),
+ typePath{}.addField("AssetParams").addValue().addField("URL"),
+ typePath{}.addField("AppLocalStates").addValue().addField("KeyValue").addMapKey(),
+ typePath{}.addField("AppLocalStates").addValue().addField("KeyValue").addValue().addField("Bytes"),
+ typePath{}.addField("AppParams").addValue().addField("GlobalState").addMapKey(),
+ typePath{}.addField("AppParams").addValue().addField("GlobalState").addValue().addField("Bytes"),
+ }
+
+ seen := make(map[reflect.Type]bool)
+
+ checkReferencedTypes(seen, nil, []reflect.Type{typeToCheck}, makeTypeCheckFunction(t, exceptions, typeToCheck))
+}