aboutsummaryrefslogtreecommitdiff
path: root/ast/modify_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'ast/modify_test.go')
-rw-r--r--ast/modify_test.go152
1 files changed, 152 insertions, 0 deletions
diff --git a/ast/modify_test.go b/ast/modify_test.go
new file mode 100644
index 0000000..6b5019f
--- /dev/null
+++ b/ast/modify_test.go
@@ -0,0 +1,152 @@
+package ast
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestModify(t *testing.T) {
+ one := func() Expression { return &IntegerLiteral{Value: 1} }
+ two := func() Expression { return &IntegerLiteral{Value: 2} }
+
+ turnOneIntoTwo := func(node Node) Node {
+ integer, ok := node.(*IntegerLiteral)
+ if !ok {
+ return node
+ }
+
+ if integer.Value != 1 {
+ return node
+ }
+
+ integer.Value = 2
+ return integer
+ }
+
+ tests := []struct {
+ input Node
+ expected Node
+ }{
+ {
+ one(),
+ two(),
+ },
+ {
+ &Program{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ &Program{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+
+ {
+ &InfixExpression{Left: one(), Operator: "+", Right: two()},
+ &InfixExpression{Left: two(), Operator: "+", Right: two()},
+ },
+ {
+ &InfixExpression{Left: two(), Operator: "+", Right: one()},
+ &InfixExpression{Left: two(), Operator: "+", Right: two()},
+ },
+ {
+ &PrefixExpression{Operator: "-", Right: one()},
+ &PrefixExpression{Operator: "-", Right: two()},
+ },
+ {
+ &IndexExpression{Left: one(), Index: one()},
+ &IndexExpression{Left: two(), Index: two()},
+ },
+ {
+ &IfExpression{
+ Condition: one(),
+ Consequence: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ Alternative: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ },
+ &IfExpression{
+ Condition: two(),
+ Consequence: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ Alternative: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+ },
+ {
+ &ReturnStatement{ReturnValue: one()},
+ &ReturnStatement{ReturnValue: two()},
+ },
+ {
+ &LetStatement{Value: one()},
+ &LetStatement{Value: two()},
+ },
+ {
+ &FunctionLiteral{
+ Parameters: []*Identifier{},
+ Body: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ },
+ &FunctionLiteral{
+ Parameters: []*Identifier{},
+ Body: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+ },
+ {
+ &ArrayLiteral{Elements: []Expression{one(), one()}},
+ &ArrayLiteral{Elements: []Expression{two(), two()}},
+ },
+ }
+
+ for _, tt := range tests {
+ modified := Modify(tt.input, turnOneIntoTwo)
+
+ equal := reflect.DeepEqual(modified, tt.expected)
+ if !equal {
+ t.Errorf("not equal. got=%#v, want=%#v",
+ modified, tt.expected)
+ }
+ }
+
+ hashLiteral := &HashLiteral{
+ Pairs: map[Expression]Expression{
+ one(): one(),
+ one(): one(),
+ },
+ }
+
+ Modify(hashLiteral, turnOneIntoTwo)
+
+ for key, val := range hashLiteral.Pairs {
+ key, _ := key.(*IntegerLiteral)
+ if key.Value != 2 {
+ t.Errorf("value is not %d, got=%d", 2, key.Value)
+ }
+ val, _ := val.(*IntegerLiteral)
+ if val.Value != 2 {
+ t.Errorf("value is not %d, got=%d", 2, val.Value)
+ }
+ }
+}