aboutsummaryrefslogtreecommitdiff
path: root/ast
diff options
context:
space:
mode:
authorDimitri Sokolyuk <demon@dim13.org>2018-03-25 01:50:10 +0100
committerDimitri Sokolyuk <demon@dim13.org>2018-03-25 01:50:10 +0100
commite5ed6e13a4adbbe61317194af36c33c82b33c90f (patch)
tree790b5c954959ac852f0072f8bb2bd3137a4dac48 /ast
parent88efa3eb20001d1dc23e6b5a8413ea7adf10294e (diff)
lost chapter
Diffstat (limited to 'ast')
-rw-r--r--ast/ast.go25
-rw-r--r--ast/modify.go68
-rw-r--r--ast/modify_test.go152
3 files changed, 245 insertions, 0 deletions
diff --git a/ast/ast.go b/ast/ast.go
index 95b0986..1338f14 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -338,3 +338,28 @@ func (hl *HashLiteral) String() string {
return out.String()
}
+
+type MacroLiteral struct {
+ Token token.Token // The 'macro' token
+ Parameters []*Identifier
+ Body *BlockStatement
+}
+
+func (ml *MacroLiteral) expressionNode() {}
+func (ml *MacroLiteral) TokenLiteral() string { return ml.Token.Literal }
+func (ml *MacroLiteral) String() string {
+ var out bytes.Buffer
+
+ params := []string{}
+ for _, p := range ml.Parameters {
+ params = append(params, p.String())
+ }
+
+ out.WriteString(ml.TokenLiteral())
+ out.WriteString("(")
+ out.WriteString(strings.Join(params, ", "))
+ out.WriteString(") ")
+ out.WriteString(ml.Body.String())
+
+ return out.String()
+}
diff --git a/ast/modify.go b/ast/modify.go
new file mode 100644
index 0000000..60567bf
--- /dev/null
+++ b/ast/modify.go
@@ -0,0 +1,68 @@
+package ast
+
+type ModifierFunc func(Node) Node
+
+func Modify(node Node, modifier ModifierFunc) Node {
+ switch node := node.(type) {
+
+ case *Program:
+ for i, statement := range node.Statements {
+ node.Statements[i], _ = Modify(statement, modifier).(Statement)
+ }
+
+ case *ExpressionStatement:
+ node.Expression, _ = Modify(node.Expression, modifier).(Expression)
+
+ case *InfixExpression:
+ node.Left, _ = Modify(node.Left, modifier).(Expression)
+ node.Right, _ = Modify(node.Right, modifier).(Expression)
+
+ case *PrefixExpression:
+ node.Right, _ = Modify(node.Right, modifier).(Expression)
+
+ case *IndexExpression:
+ node.Left, _ = Modify(node.Left, modifier).(Expression)
+ node.Index, _ = Modify(node.Index, modifier).(Expression)
+
+ case *IfExpression:
+ node.Condition, _ = Modify(node.Condition, modifier).(Expression)
+ node.Consequence, _ = Modify(node.Consequence, modifier).(*BlockStatement)
+ if node.Alternative != nil {
+ node.Alternative, _ = Modify(node.Alternative, modifier).(*BlockStatement)
+ }
+
+ case *BlockStatement:
+ for i, _ := range node.Statements {
+ node.Statements[i], _ = Modify(node.Statements[i], modifier).(Statement)
+ }
+
+ case *ReturnStatement:
+ node.ReturnValue, _ = Modify(node.ReturnValue, modifier).(Expression)
+
+ case *LetStatement:
+ node.Value, _ = Modify(node.Value, modifier).(Expression)
+
+ case *FunctionLiteral:
+ for i, _ := range node.Parameters {
+ node.Parameters[i], _ = Modify(node.Parameters[i], modifier).(*Identifier)
+ }
+ node.Body, _ = Modify(node.Body, modifier).(*BlockStatement)
+
+ case *ArrayLiteral:
+ for i, _ := range node.Elements {
+ node.Elements[i], _ = Modify(node.Elements[i], modifier).(Expression)
+ }
+
+ case *HashLiteral:
+ newPairs := make(map[Expression]Expression)
+ for key, val := range node.Pairs {
+ newKey, _ := Modify(key, modifier).(Expression)
+ newVal, _ := Modify(val, modifier).(Expression)
+ newPairs[newKey] = newVal
+ }
+ node.Pairs = newPairs
+
+ }
+
+ return modifier(node)
+}
diff --git a/ast/modify_test.go b/ast/modify_test.go
new file mode 100644
index 0000000..6b5019f
--- /dev/null
+++ b/ast/modify_test.go
@@ -0,0 +1,152 @@
+package ast
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestModify(t *testing.T) {
+ one := func() Expression { return &IntegerLiteral{Value: 1} }
+ two := func() Expression { return &IntegerLiteral{Value: 2} }
+
+ turnOneIntoTwo := func(node Node) Node {
+ integer, ok := node.(*IntegerLiteral)
+ if !ok {
+ return node
+ }
+
+ if integer.Value != 1 {
+ return node
+ }
+
+ integer.Value = 2
+ return integer
+ }
+
+ tests := []struct {
+ input Node
+ expected Node
+ }{
+ {
+ one(),
+ two(),
+ },
+ {
+ &Program{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ &Program{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+
+ {
+ &InfixExpression{Left: one(), Operator: "+", Right: two()},
+ &InfixExpression{Left: two(), Operator: "+", Right: two()},
+ },
+ {
+ &InfixExpression{Left: two(), Operator: "+", Right: one()},
+ &InfixExpression{Left: two(), Operator: "+", Right: two()},
+ },
+ {
+ &PrefixExpression{Operator: "-", Right: one()},
+ &PrefixExpression{Operator: "-", Right: two()},
+ },
+ {
+ &IndexExpression{Left: one(), Index: one()},
+ &IndexExpression{Left: two(), Index: two()},
+ },
+ {
+ &IfExpression{
+ Condition: one(),
+ Consequence: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ Alternative: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ },
+ &IfExpression{
+ Condition: two(),
+ Consequence: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ Alternative: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+ },
+ {
+ &ReturnStatement{ReturnValue: one()},
+ &ReturnStatement{ReturnValue: two()},
+ },
+ {
+ &LetStatement{Value: one()},
+ &LetStatement{Value: two()},
+ },
+ {
+ &FunctionLiteral{
+ Parameters: []*Identifier{},
+ Body: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ },
+ &FunctionLiteral{
+ Parameters: []*Identifier{},
+ Body: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+ },
+ {
+ &ArrayLiteral{Elements: []Expression{one(), one()}},
+ &ArrayLiteral{Elements: []Expression{two(), two()}},
+ },
+ }
+
+ for _, tt := range tests {
+ modified := Modify(tt.input, turnOneIntoTwo)
+
+ equal := reflect.DeepEqual(modified, tt.expected)
+ if !equal {
+ t.Errorf("not equal. got=%#v, want=%#v",
+ modified, tt.expected)
+ }
+ }
+
+ hashLiteral := &HashLiteral{
+ Pairs: map[Expression]Expression{
+ one(): one(),
+ one(): one(),
+ },
+ }
+
+ Modify(hashLiteral, turnOneIntoTwo)
+
+ for key, val := range hashLiteral.Pairs {
+ key, _ := key.(*IntegerLiteral)
+ if key.Value != 2 {
+ t.Errorf("value is not %d, got=%d", 2, key.Value)
+ }
+ val, _ := val.(*IntegerLiteral)
+ if val.Value != 2 {
+ t.Errorf("value is not %d, got=%d", 2, val.Value)
+ }
+ }
+}