aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ast/ast.go25
-rw-r--r--ast/modify.go68
-rw-r--r--ast/modify_test.go152
-rw-r--r--evaluator/evaluator.go3
-rw-r--r--evaluator/macro_expansion.go123
-rw-r--r--evaluator/macro_expansion_test.go123
-rw-r--r--evaluator/quote_unquote.go68
-rw-r--r--evaluator/quote_unquote_test.go117
-rw-r--r--lexer/lexer_test.go14
-rw-r--r--object/object.go36
-rw-r--r--parser/parser.go19
-rw-r--r--parser/parser_test.go47
-rw-r--r--repl/repl.go6
-rw-r--r--token/token.go2
14 files changed, 802 insertions, 1 deletions
diff --git a/ast/ast.go b/ast/ast.go
index 95b0986..1338f14 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -338,3 +338,28 @@ func (hl *HashLiteral) String() string {
return out.String()
}
+
+type MacroLiteral struct {
+ Token token.Token // The 'macro' token
+ Parameters []*Identifier
+ Body *BlockStatement
+}
+
+func (ml *MacroLiteral) expressionNode() {}
+func (ml *MacroLiteral) TokenLiteral() string { return ml.Token.Literal }
+func (ml *MacroLiteral) String() string {
+ var out bytes.Buffer
+
+ params := []string{}
+ for _, p := range ml.Parameters {
+ params = append(params, p.String())
+ }
+
+ out.WriteString(ml.TokenLiteral())
+ out.WriteString("(")
+ out.WriteString(strings.Join(params, ", "))
+ out.WriteString(") ")
+ out.WriteString(ml.Body.String())
+
+ return out.String()
+}
diff --git a/ast/modify.go b/ast/modify.go
new file mode 100644
index 0000000..60567bf
--- /dev/null
+++ b/ast/modify.go
@@ -0,0 +1,68 @@
+package ast
+
+type ModifierFunc func(Node) Node
+
+func Modify(node Node, modifier ModifierFunc) Node {
+ switch node := node.(type) {
+
+ case *Program:
+ for i, statement := range node.Statements {
+ node.Statements[i], _ = Modify(statement, modifier).(Statement)
+ }
+
+ case *ExpressionStatement:
+ node.Expression, _ = Modify(node.Expression, modifier).(Expression)
+
+ case *InfixExpression:
+ node.Left, _ = Modify(node.Left, modifier).(Expression)
+ node.Right, _ = Modify(node.Right, modifier).(Expression)
+
+ case *PrefixExpression:
+ node.Right, _ = Modify(node.Right, modifier).(Expression)
+
+ case *IndexExpression:
+ node.Left, _ = Modify(node.Left, modifier).(Expression)
+ node.Index, _ = Modify(node.Index, modifier).(Expression)
+
+ case *IfExpression:
+ node.Condition, _ = Modify(node.Condition, modifier).(Expression)
+ node.Consequence, _ = Modify(node.Consequence, modifier).(*BlockStatement)
+ if node.Alternative != nil {
+ node.Alternative, _ = Modify(node.Alternative, modifier).(*BlockStatement)
+ }
+
+ case *BlockStatement:
+ for i, _ := range node.Statements {
+ node.Statements[i], _ = Modify(node.Statements[i], modifier).(Statement)
+ }
+
+ case *ReturnStatement:
+ node.ReturnValue, _ = Modify(node.ReturnValue, modifier).(Expression)
+
+ case *LetStatement:
+ node.Value, _ = Modify(node.Value, modifier).(Expression)
+
+ case *FunctionLiteral:
+ for i, _ := range node.Parameters {
+ node.Parameters[i], _ = Modify(node.Parameters[i], modifier).(*Identifier)
+ }
+ node.Body, _ = Modify(node.Body, modifier).(*BlockStatement)
+
+ case *ArrayLiteral:
+ for i, _ := range node.Elements {
+ node.Elements[i], _ = Modify(node.Elements[i], modifier).(Expression)
+ }
+
+ case *HashLiteral:
+ newPairs := make(map[Expression]Expression)
+ for key, val := range node.Pairs {
+ newKey, _ := Modify(key, modifier).(Expression)
+ newVal, _ := Modify(val, modifier).(Expression)
+ newPairs[newKey] = newVal
+ }
+ node.Pairs = newPairs
+
+ }
+
+ return modifier(node)
+}
diff --git a/ast/modify_test.go b/ast/modify_test.go
new file mode 100644
index 0000000..6b5019f
--- /dev/null
+++ b/ast/modify_test.go
@@ -0,0 +1,152 @@
+package ast
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestModify(t *testing.T) {
+ one := func() Expression { return &IntegerLiteral{Value: 1} }
+ two := func() Expression { return &IntegerLiteral{Value: 2} }
+
+ turnOneIntoTwo := func(node Node) Node {
+ integer, ok := node.(*IntegerLiteral)
+ if !ok {
+ return node
+ }
+
+ if integer.Value != 1 {
+ return node
+ }
+
+ integer.Value = 2
+ return integer
+ }
+
+ tests := []struct {
+ input Node
+ expected Node
+ }{
+ {
+ one(),
+ two(),
+ },
+ {
+ &Program{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ &Program{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+
+ {
+ &InfixExpression{Left: one(), Operator: "+", Right: two()},
+ &InfixExpression{Left: two(), Operator: "+", Right: two()},
+ },
+ {
+ &InfixExpression{Left: two(), Operator: "+", Right: one()},
+ &InfixExpression{Left: two(), Operator: "+", Right: two()},
+ },
+ {
+ &PrefixExpression{Operator: "-", Right: one()},
+ &PrefixExpression{Operator: "-", Right: two()},
+ },
+ {
+ &IndexExpression{Left: one(), Index: one()},
+ &IndexExpression{Left: two(), Index: two()},
+ },
+ {
+ &IfExpression{
+ Condition: one(),
+ Consequence: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ Alternative: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ },
+ &IfExpression{
+ Condition: two(),
+ Consequence: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ Alternative: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+ },
+ {
+ &ReturnStatement{ReturnValue: one()},
+ &ReturnStatement{ReturnValue: two()},
+ },
+ {
+ &LetStatement{Value: one()},
+ &LetStatement{Value: two()},
+ },
+ {
+ &FunctionLiteral{
+ Parameters: []*Identifier{},
+ Body: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: one()},
+ },
+ },
+ },
+ &FunctionLiteral{
+ Parameters: []*Identifier{},
+ Body: &BlockStatement{
+ Statements: []Statement{
+ &ExpressionStatement{Expression: two()},
+ },
+ },
+ },
+ },
+ {
+ &ArrayLiteral{Elements: []Expression{one(), one()}},
+ &ArrayLiteral{Elements: []Expression{two(), two()}},
+ },
+ }
+
+ for _, tt := range tests {
+ modified := Modify(tt.input, turnOneIntoTwo)
+
+ equal := reflect.DeepEqual(modified, tt.expected)
+ if !equal {
+ t.Errorf("not equal. got=%#v, want=%#v",
+ modified, tt.expected)
+ }
+ }
+
+ hashLiteral := &HashLiteral{
+ Pairs: map[Expression]Expression{
+ one(): one(),
+ one(): one(),
+ },
+ }
+
+ Modify(hashLiteral, turnOneIntoTwo)
+
+ for key, val := range hashLiteral.Pairs {
+ key, _ := key.(*IntegerLiteral)
+ if key.Value != 2 {
+ t.Errorf("value is not %d, got=%d", 2, key.Value)
+ }
+ val, _ := val.(*IntegerLiteral)
+ if val.Value != 2 {
+ t.Errorf("value is not %d, got=%d", 2, val.Value)
+ }
+ }
+}
diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go
index a6b21cc..6ea4b01 100644
--- a/evaluator/evaluator.go
+++ b/evaluator/evaluator.go
@@ -82,6 +82,9 @@ func Eval(node ast.Node, env *object.Environment) object.Object {
return &object.Function{Parameters: params, Env: env, Body: body}
case *ast.CallExpression:
+ if node.Function.TokenLiteral() == "quote" {
+ return quote(node.Arguments[0], env)
+ }
function := Eval(node.Function, env)
if isError(function) {
return function
diff --git a/evaluator/macro_expansion.go b/evaluator/macro_expansion.go
new file mode 100644
index 0000000..1fdac0a
--- /dev/null
+++ b/evaluator/macro_expansion.go
@@ -0,0 +1,123 @@
+package evaluator
+
+import (
+ "monkey/ast"
+ "monkey/object"
+)
+
+func DefineMacros(program *ast.Program, env *object.Environment) {
+ definitions := []int{}
+
+ for i, statement := range program.Statements {
+ if isMacroDefinition(statement) {
+ addMacro(statement, env)
+ definitions = append(definitions, i)
+ }
+ }
+
+ for i := len(definitions) - 1; i >= 0; i = i - 1 {
+ definitionIndex := definitions[i]
+ program.Statements = append(
+ program.Statements[:definitionIndex],
+ program.Statements[definitionIndex+1:]...,
+ )
+ }
+}
+
+func isMacroDefinition(node ast.Statement) bool {
+ letStatement, ok := node.(*ast.LetStatement)
+ if !ok {
+ return false
+ }
+
+ _, ok = letStatement.Value.(*ast.MacroLiteral)
+ if !ok {
+ return false
+ }
+
+ return true
+}
+
+func addMacro(stmt ast.Statement, env *object.Environment) {
+ letStatement, _ := stmt.(*ast.LetStatement)
+ macroLiteral, _ := letStatement.Value.(*ast.MacroLiteral)
+
+ macro := &object.Macro{
+ Parameters: macroLiteral.Parameters,
+ Env: env,
+ Body: macroLiteral.Body,
+ }
+
+ env.Set(letStatement.Name.Value, macro)
+}
+
+func ExpandMacros(program ast.Node, env *object.Environment) ast.Node {
+ return ast.Modify(program, func(node ast.Node) ast.Node {
+ callExpression, ok := node.(*ast.CallExpression)
+ if !ok {
+ return node
+ }
+
+ macro, ok := isMacroCall(callExpression, env)
+ if !ok {
+ return node
+ }
+
+ args := quoteArgs(callExpression)
+ evalEnv := extendMacroEnv(macro, args)
+
+ evaluated := Eval(macro.Body, evalEnv)
+
+ quote, ok := evaluated.(*object.Quote)
+ if !ok {
+ panic("we only support returning AST-nodes from macros")
+ }
+
+ return quote.Node
+ })
+}
+
+func isMacroCall(
+ exp *ast.CallExpression,
+ env *object.Environment,
+) (*object.Macro, bool) {
+ identifier, ok := exp.Function.(*ast.Identifier)
+ if !ok {
+ return nil, false
+ }
+
+ obj, ok := env.Get(identifier.Value)
+ if !ok {
+ return nil, false
+ }
+
+ macro, ok := obj.(*object.Macro)
+ if !ok {
+ return nil, false
+ }
+
+ return macro, true
+}
+
+func quoteArgs(exp *ast.CallExpression) []*object.Quote {
+ args := []*object.Quote{}
+
+ for _, a := range exp.Arguments {
+ args = append(args, &object.Quote{Node: a})
+ }
+
+ return args
+}
+
+func extendMacroEnv(
+ macro *object.Macro,
+ args []*object.Quote,
+) *object.Environment {
+ extended := object.NewEnclosedEnvironment(macro.Env)
+
+ for paramIdx, param := range macro.Parameters {
+ extended.Set(param.Value, args[paramIdx])
+ }
+
+ return extended
+}
diff --git a/evaluator/macro_expansion_test.go b/evaluator/macro_expansion_test.go
new file mode 100644
index 0000000..4a67f35
--- /dev/null
+++ b/evaluator/macro_expansion_test.go
@@ -0,0 +1,123 @@
+package evaluator
+
+import (
+ "monkey/ast"
+ "monkey/lexer"
+ "monkey/object"
+ "monkey/parser"
+ "testing"
+)
+
+func TestDefineMacros(t *testing.T) {
+ input := `
+let number = 1;
+let function = fn(x, y) { x + y };
+let mymacro = macro(x, y) { x + y; };
+`
+
+ env := object.NewEnvironment()
+ program := testParseProgram(input)
+
+ DefineMacros(program, env)
+
+ if len(program.Statements) != 2 {
+ t.Fatalf("Wrong number of statements. got=%d",
+ len(program.Statements))
+ }
+
+ _, ok := env.Get("number")
+ if ok {
+ t.Fatalf("number should not be defined")
+ }
+ _, ok = env.Get("function")
+ if ok {
+ t.Fatalf("function should not be defined")
+ }
+
+ obj, ok := env.Get("mymacro")
+ if !ok {
+ t.Fatalf("macro not in environment.")
+ }
+
+ macro, ok := obj.(*object.Macro)
+ if !ok {
+ t.Fatalf("object is not Macro. got=%T (%+v)", obj, obj)
+ }
+
+ if len(macro.Parameters) != 2 {
+ t.Fatalf("Wrong number of macro parameters. got=%d",
+ len(macro.Parameters))
+ }
+
+ if macro.Parameters[0].String() != "x" {
+ t.Fatalf("parameter is not 'x'. got=%q", macro.Parameters[0])
+ }
+ if macro.Parameters[1].String() != "y" {
+ t.Fatalf("parameter is not 'y'. got=%q", macro.Parameters[1])
+ }
+
+ expectedBody := "(x + y)"
+
+ if macro.Body.String() != expectedBody {
+ t.Fatalf("body is not %q. got=%q", expectedBody, macro.Body.String())
+ }
+}
+
+func testParseProgram(input string) *ast.Program {
+ l := lexer.New(input)
+ p := parser.New(l)
+ return p.ParseProgram()
+}
+
+func TestExpandMacros(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {
+ `
+ let infixExpression = macro() { quote(1 + 2); };
+
+ infixExpression();
+ `,
+ `(1 + 2)`,
+ },
+ {
+ `
+ let reverse = macro(a, b) { quote(unquote(b) - unquote(a)); };
+
+ reverse(2 + 2, 10 - 5);
+ `,
+ `(10 - 5) - (2 + 2)`,
+ },
+ {
+ `
+ let unless = macro(condition, consequence, alternative) {
+ quote(if (!(unquote(condition))) {
+ unquote(consequence);
+ } else {
+ unquote(alternative);
+ });
+ };
+
+ unless(10 > 5, puts("not greater"), puts("greater"));
+ `,
+
+ `if (!(10 > 5)) { puts("not greater") } else { puts("greater") }`,
+ },
+ }
+
+ for _, tt := range tests {
+ expected := testParseProgram(tt.expected)
+ program := testParseProgram(tt.input)
+
+ env := object.NewEnvironment()
+ DefineMacros(program, env)
+ expanded := ExpandMacros(program, env)
+
+ if expanded.String() != expected.String() {
+ t.Errorf("not equal. want=%q, got=%q",
+ expected.String(), expanded.String())
+ }
+ }
+}
diff --git a/evaluator/quote_unquote.go b/evaluator/quote_unquote.go
new file mode 100644
index 0000000..d3521f9
--- /dev/null
+++ b/evaluator/quote_unquote.go
@@ -0,0 +1,68 @@
+package evaluator
+
+import (
+ "fmt"
+ "monkey/ast"
+ "monkey/object"
+ "monkey/token"
+)
+
+func quote(node ast.Node, env *object.Environment) object.Object {
+ node = evalUnquoteCalls(node, env)
+ return &object.Quote{Node: node}
+}
+
+func evalUnquoteCalls(quoted ast.Node, env *object.Environment) ast.Node {
+ return ast.Modify(quoted, func(node ast.Node) ast.Node {
+ if !isUnquoteCall(node) {
+ return node
+ }
+
+ call, ok := node.(*ast.CallExpression)
+ if !ok {
+ return node
+ }
+
+ if len(call.Arguments) != 1 {
+ return node
+ }
+
+ unquoted := Eval(call.Arguments[0], env)
+ return convertObjectToASTNode(unquoted)
+ })
+}
+
+func isUnquoteCall(node ast.Node) bool {
+ callExpression, ok := node.(*ast.CallExpression)
+ if !ok {
+ return false
+ }
+
+ return callExpression.Function.TokenLiteral() == "unquote"
+}
+
+func convertObjectToASTNode(obj object.Object) ast.Node {
+ switch obj := obj.(type) {
+ case *object.Integer:
+ t := token.Token{
+ Type: token.INT,
+ Literal: fmt.Sprintf("%d", obj.Value),
+ }
+ return &ast.IntegerLiteral{Token: t, Value: obj.Value}
+
+ case *object.Boolean:
+ var t token.Token
+ if obj.Value {
+ t = token.Token{Type: token.TRUE, Literal: "true"}
+ } else {
+ t = token.Token{Type: token.FALSE, Literal: "false"}
+ }
+ return &ast.Boolean{Token: t, Value: obj.Value}
+
+ case *object.Quote:
+ return obj.Node
+
+ default:
+ return nil
+ }
+}
diff --git a/evaluator/quote_unquote_test.go b/evaluator/quote_unquote_test.go
new file mode 100644
index 0000000..2859540
--- /dev/null
+++ b/evaluator/quote_unquote_test.go
@@ -0,0 +1,117 @@
+package evaluator
+
+import (
+ "monkey/object"
+ "testing"
+)
+
+func TestQuote(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {
+ `quote(5)`,
+ `5`,
+ },
+ {
+ `quote(5 + 8)`,
+ `(5 + 8)`,
+ },
+ {
+ `quote(foobar)`,
+ `foobar`,
+ },
+ {
+ `quote(foobar + barfoo)`,
+ `(foobar + barfoo)`,
+ },
+ }
+
+ for _, tt := range tests {
+ evaluated := testEval(tt.input)
+ quote, ok := evaluated.(*object.Quote)
+ if !ok {
+ t.Fatalf("expected *object.Quote. got=%T (%+v)",
+ evaluated, evaluated)
+ }
+
+ if quote.Node == nil {
+ t.Fatalf("quote.Node is nil")
+ }
+
+ if quote.Node.String() != tt.expected {
+ t.Errorf("not equal. got=%q, want=%q",
+ quote.Node.String(), tt.expected)
+ }
+ }
+}
+
+func TestQuoteUnquote(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {
+ `quote(unquote(4))`,
+ `4`,
+ },
+ {
+ `quote(unquote(4 + 4))`,
+ `8`,
+ },
+ {
+ `quote(8 + unquote(4 + 4))`,
+ `(8 + 8)`,
+ },
+ {
+ `quote(unquote(4 + 4) + 8)`,
+ `(8 + 8)`,
+ },
+ {
+ `let foobar = 8;
+ quote(foobar)`,
+ `foobar`,
+ },
+ {
+ `let foobar = 8;
+ quote(unquote(foobar))`,
+ `8`,
+ },
+ {
+ `quote(unquote(true))`,
+ `true`,
+ },
+ {
+ `quote(unquote(true == false))`,
+ `false`,
+ },
+ {
+ `quote(unquote(quote(4 + 4)))`,
+ `(4 + 4)`,
+ },
+ {
+ `let quotedInfixExpression = quote(4 + 4);
+ quote(unquote(4 + 4) + unquote(quotedInfixExpression))`,
+ `(8 + (4 + 4))`,
+ },
+ }
+
+ for _, tt := range tests {
+ evaluated := testEval(tt.input)
+ quote, ok := evaluated.(*object.Quote)
+ if !ok {
+ t.Fatalf("expected *object.Quote. got=%T (%+v)",
+ evaluated, evaluated)
+ }
+
+ if quote.Node == nil {
+ t.Fatalf("quote.Node is nil")
+ }
+
+ if quote.Node.String() != tt.expected {
+ t.Errorf("not equal. got=%q, want=%q",
+ quote.Node.String(), tt.expected)
+ }
+ }
+}
diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go
index e4e64a4..f63d7c3 100644
--- a/lexer/lexer_test.go
+++ b/lexer/lexer_test.go
@@ -30,6 +30,7 @@ if (5 < 10) {
"foo bar"
[1, 2];
{"foo": "bar"}
+macro(x, y) { x + y; };
`
tests := []struct {
@@ -122,6 +123,19 @@ if (5 < 10) {
{token.COLON, ":"},
{token.STRING, "bar"},
{token.RBRACE, "}"},
+ {token.MACRO, "macro"},
+ {token.LPAREN, "("},
+ {token.IDENT, "x"},
+ {token.COMMA, ","},
+ {token.IDENT, "y"},
+ {token.RPAREN, ")"},
+ {token.LBRACE, "{"},
+ {token.IDENT, "x"},
+ {token.PLUS, "+"},
+ {token.IDENT, "y"},
+ {token.SEMICOLON, ";"},
+ {token.RBRACE, "}"},
+ {token.SEMICOLON, ";"},
{token.EOF, ""},
}
diff --git a/object/object.go b/object/object.go
index daafcac..7faf808 100644
--- a/object/object.go
+++ b/object/object.go
@@ -28,6 +28,8 @@ const (
ARRAY_OBJ = "ARRAY"
HASH_OBJ = "HASH"
+ QUOTE_OBJ = "QUOTE"
+ MACRO_OBJ = "MACRO"
)
type HashKey struct {
@@ -181,3 +183,37 @@ func (h *Hash) Inspect() string {
return out.String()
}
+
+type Quote struct {
+ Node ast.Node
+}
+
+func (q *Quote) Type() ObjectType { return QUOTE_OBJ }
+func (q *Quote) Inspect() string {
+ return "QUOTE(" + q.Node.String() + ")"
+}
+
+type Macro struct {
+ Parameters []*ast.Identifier
+ Body *ast.BlockStatement
+ Env *Environment
+}
+
+func (m *Macro) Type() ObjectType { return MACRO_OBJ }
+func (m *Macro) Inspect() string {
+ var out bytes.Buffer
+
+ params := []string{}
+ for _, p := range m.Parameters {
+ params = append(params, p.String())
+ }
+
+ out.WriteString("macro")
+ out.WriteString("(")
+ out.WriteString(strings.Join(params, ", "))
+ out.WriteString(") {\n")
+ out.WriteString(m.Body.String())
+ out.WriteString("\n}")
+
+ return out.String()
+}
diff --git a/parser/parser.go b/parser/parser.go
index 300f2d7..4795761 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -68,6 +68,7 @@ func New(l *lexer.Lexer) *Parser {
p.registerPrefix(token.FUNCTION, p.parseFunctionLiteral)
p.registerPrefix(token.LBRACKET, p.parseArrayLiteral)
p.registerPrefix(token.LBRACE, p.parseHashLiteral)
+ p.registerPrefix(token.MACRO, p.parseMacroLiteral)
p.infixParseFns = make(map[token.TokenType]infixParseFn)
p.registerInfix(token.PLUS, p.parseInfixExpression)
@@ -482,6 +483,24 @@ func (p *Parser) parseHashLiteral() ast.Expression {
return hash
}
+func (p *Parser) parseMacroLiteral() ast.Expression {
+ lit := &ast.MacroLiteral{Token: p.curToken}
+
+ if !p.expectPeek(token.LPAREN) {
+ return nil
+ }
+
+ lit.Parameters = p.parseFunctionParameters()
+
+ if !p.expectPeek(token.LBRACE) {
+ return nil
+ }
+
+ lit.Body = p.parseBlockStatement()
+
+ return lit
+}
+
func (p *Parser) registerPrefix(tokenType token.TokenType, fn prefixParseFn) {
p.prefixParseFns[tokenType] = fn
}
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 7fe5268..df040ad 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -1088,3 +1088,50 @@ func checkParserErrors(t *testing.T, p *Parser) {
}
t.FailNow()
}
+
+func TestMacroLiteralParsing(t *testing.T) {
+ input := `macro(x, y) { x + y; }`
+
+ l := lexer.New(input)
+ p := New(l)
+ program := p.ParseProgram()
+ checkParserErrors(t, p)
+
+ if len(program.Statements) != 1 {
+ t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
+ 1, len(program.Statements))
+ }
+
+ stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
+ if !ok {
+ t.Fatalf("statement is not ast.ExpressionStatement. got=%T",
+ program.Statements[0])
+ }
+
+ macro, ok := stmt.Expression.(*ast.MacroLiteral)
+ if !ok {
+ t.Fatalf("stmt.Expression is not ast.MacroLiteral. got=%T",
+ stmt.Expression)
+ }
+
+ if len(macro.Parameters) != 2 {
+ t.Fatalf("macro literal parameters wrong. want 2, got=%d\n",
+ len(macro.Parameters))
+ }
+
+ testLiteralExpression(t, macro.Parameters[0], "x")
+ testLiteralExpression(t, macro.Parameters[1], "y")
+
+ if len(macro.Body.Statements) != 1 {
+ t.Fatalf("macro.Body.Statements has not 1 statements. got=%d\n",
+ len(macro.Body.Statements))
+ }
+
+ bodyStmt, ok := macro.Body.Statements[0].(*ast.ExpressionStatement)
+ if !ok {
+ t.Fatalf("macro body stmt is not ast.ExpressionStatement. got=%T",
+ macro.Body.Statements[0])
+ }
+
+ testInfixExpression(t, bodyStmt.Expression, "x", "+", "y")
+}
diff --git a/repl/repl.go b/repl/repl.go
index 3d75f8a..e267e1d 100644
--- a/repl/repl.go
+++ b/repl/repl.go
@@ -16,6 +16,7 @@ const PROMPT = ">> "
func Start(in io.Reader, out io.Writer) {
scanner := bufio.NewScanner(in)
env := object.NewEnvironment()
+ macroEnv := object.NewEnvironment()
for {
fmt.Print(PROMPT)
@@ -34,7 +35,10 @@ func Start(in io.Reader, out io.Writer) {
continue
}
- evaluated := evaluator.Eval(program, env)
+ evaluator.DefineMacros(program, macroEnv)
+ expanded := evaluator.ExpandMacros(program, macroEnv)
+
+ evaluated := evaluator.Eval(expanded, env)
if evaluated != nil {
io.WriteString(out, evaluated.Inspect())
io.WriteString(out, "\n")
diff --git a/token/token.go b/token/token.go
index 3d2d2f7..c7bf73a 100644
--- a/token/token.go
+++ b/token/token.go
@@ -45,6 +45,7 @@ const (
IF = "IF"
ELSE = "ELSE"
RETURN = "RETURN"
+ MACRO = "MACRO"
)
type Token struct {
@@ -60,6 +61,7 @@ var keywords = map[string]TokenType{
"if": IF,
"else": ELSE,
"return": RETURN,
+ "macro": MACRO,
}
func LookupIdent(ident string) TokenType {