From 523f40968b9c1a23da1f4a1c2f125197d7611fef Mon Sep 17 00:00:00 2001 From: Dimitri Sokolyuk Date: Sat, 20 May 2017 16:28:10 +0200 Subject: 04 --- ast/ast.go | 73 ++++++++++++ evaluator/builtins.go | 117 +++++++++++++++++++ evaluator/evaluator.go | 137 ++++++++++++++++++++-- evaluator/evaluator_test.go | 278 ++++++++++++++++++++++++++++++++++++++++++++ lexer/lexer.go | 20 ++++ lexer/lexer_test.go | 17 +++ object/object.go | 97 ++++++++++++++++ object/object_test.go | 60 ++++++++++ parser/parser.go | 78 +++++++++++-- parser/parser_test.go | 266 ++++++++++++++++++++++++++++++++++++++++++ token/token.go | 16 ++- 11 files changed, 1134 insertions(+), 25 deletions(-) create mode 100644 evaluator/builtins.go create mode 100644 object/object_test.go diff --git a/ast/ast.go b/ast/ast.go index e8c133f..fb30b05 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -264,3 +264,76 @@ func (ce *CallExpression) String() string { return out.String() } + +type StringLiteral struct { + Token token.Token + Value string +} + +func (sl *StringLiteral) expressionNode() {} +func (sl *StringLiteral) TokenLiteral() string { return sl.Token.Literal } +func (sl *StringLiteral) String() string { return sl.Token.Literal } + +type ArrayLiteral struct { + Token token.Token // the '[' token + Elements []Expression +} + +func (al *ArrayLiteral) expressionNode() {} +func (al *ArrayLiteral) TokenLiteral() string { return al.Token.Literal } +func (al *ArrayLiteral) String() string { + var out bytes.Buffer + + elements := []string{} + for _, el := range al.Elements { + elements = append(elements, el.String()) + } + + out.WriteString("[") + out.WriteString(strings.Join(elements, ", ")) + out.WriteString("]") + + return out.String() +} + +type IndexExpression struct { + Token token.Token // The [ token + Left Expression + Index Expression +} + +func (ie *IndexExpression) expressionNode() {} +func (ie *IndexExpression) TokenLiteral() string { return ie.Token.Literal } +func (ie *IndexExpression) String() string 
{ + var out bytes.Buffer + + out.WriteString("(") + out.WriteString(ie.Left.String()) + out.WriteString("[") + out.WriteString(ie.Index.String()) + out.WriteString("])") + + return out.String() +} + +type HashLiteral struct { + Token token.Token // the '{' token + Pairs map[Expression]Expression +} + +func (hl *HashLiteral) expressionNode() {} +func (hl *HashLiteral) TokenLiteral() string { return hl.Token.Literal } +func (hl *HashLiteral) String() string { + var out bytes.Buffer + + pairs := []string{} + for key, value := range hl.Pairs { + pairs = append(pairs, key.String()+":"+value.String()) + } + + out.WriteString("{") + out.WriteString(strings.Join(pairs, ", ")) + out.WriteString("}") + + return out.String() +} diff --git a/evaluator/builtins.go b/evaluator/builtins.go new file mode 100644 index 0000000..68eadcd --- /dev/null +++ b/evaluator/builtins.go @@ -0,0 +1,117 @@ +package evaluator + +import ( + "fmt" + "monkey/object" +) + +var builtins = map[string]*object.Builtin{ + "len": &object.Builtin{Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. got=%d, want=1", + len(args)) + } + + switch arg := args[0].(type) { + case *object.Array: + return &object.Integer{Value: int64(len(arg.Elements))} + case *object.String: + return &object.Integer{Value: int64(len(arg.Value))} + default: + return newError("argument to `len` not supported, got %s", + args[0].Type()) + } + }, + }, + "puts": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + for _, arg := range args { + fmt.Println(arg.Inspect()) + } + + return NULL + }, + }, + "first": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. 
got=%d, want=1", + len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `first` must be ARRAY, got %s", + args[0].Type()) + } + + arr := args[0].(*object.Array) + if len(arr.Elements) > 0 { + return arr.Elements[0] + } + + return NULL + }, + }, + "last": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. got=%d, want=1", + len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `last` must be ARRAY, got %s", + args[0].Type()) + } + + arr := args[0].(*object.Array) + length := len(arr.Elements) + if length > 0 { + return arr.Elements[length-1] + } + + return NULL + }, + }, + "rest": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 1 { + return newError("wrong number of arguments. got=%d, want=1", + len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `rest` must be ARRAY, got %s", + args[0].Type()) + } + + arr := args[0].(*object.Array) + length := len(arr.Elements) + if length > 0 { + newElements := make([]object.Object, length-1, length-1) + copy(newElements, arr.Elements[1:length]) + return &object.Array{Elements: newElements} + } + + return NULL + }, + }, + "push": &object.Builtin{ + Fn: func(args ...object.Object) object.Object { + if len(args) != 2 { + return newError("wrong number of arguments. 
got=%d, want=2", + len(args)) + } + if args[0].Type() != object.ARRAY_OBJ { + return newError("argument to `push` must be ARRAY, got %s", + args[0].Type()) + } + + arr := args[0].(*object.Array) + length := len(arr.Elements) + + newElements := make([]object.Object, length+1, length+1) + copy(newElements, arr.Elements) + newElements[length] = args[1] + + return &object.Array{Elements: newElements} + }, + }, +} diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index c8fa09a..50b2ab5 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -43,6 +43,9 @@ func Eval(node ast.Node, env *object.Environment) object.Object { case *ast.IntegerLiteral: return &object.Integer{Value: node.Value} + case *ast.StringLiteral: + return &object.String{Value: node.Value} + case *ast.Boolean: return nativeBoolToBooleanObject(node.Value) @@ -89,6 +92,28 @@ func Eval(node ast.Node, env *object.Environment) object.Object { } return applyFunction(function, args) + + case *ast.ArrayLiteral: + elements := evalExpressions(node.Elements, env) + if len(elements) == 1 && isError(elements[0]) { + return elements[0] + } + return &object.Array{Elements: elements} + + case *ast.IndexExpression: + left := Eval(node.Left, env) + if isError(left) { + return left + } + index := Eval(node.Index, env) + if isError(index) { + return index + } + return evalIndexExpression(left, index) + + case *ast.HashLiteral: + return evalHashLiteral(node, env) + } return nil @@ -156,6 +181,8 @@ func evalInfixExpression( switch { case left.Type() == object.INTEGER_OBJ && right.Type() == object.INTEGER_OBJ: return evalIntegerInfixExpression(operator, left, right) + case left.Type() == object.STRING_OBJ && right.Type() == object.STRING_OBJ: + return evalStringInfixExpression(operator, left, right) case operator == "==": return nativeBoolToBooleanObject(left == right) case operator == "!=": @@ -221,6 +248,20 @@ func evalIntegerInfixExpression( } } +func evalStringInfixExpression( + operator string, + 
left, right object.Object, +) object.Object { + if operator != "+" { + return newError("unknown operator: %s %s %s", + left.Type(), operator, right.Type()) + } + + leftVal := left.(*object.String).Value + rightVal := right.(*object.String).Value + return &object.String{Value: leftVal + rightVal} +} + func evalIfExpression( ie *ast.IfExpression, env *object.Environment, @@ -243,12 +284,15 @@ func evalIdentifier( node *ast.Identifier, env *object.Environment, ) object.Object { - val, ok := env.Get(node.Value) - if !ok { - return newError("identifier not found: " + node.Value) + if val, ok := env.Get(node.Value); ok { + return val + } + + if builtin, ok := builtins[node.Value]; ok { + return builtin } - return val + return newError("identifier not found: " + node.Value) } func isTruthy(obj object.Object) bool { @@ -293,14 +337,19 @@ func evalExpressions( } func applyFunction(fn object.Object, args []object.Object) object.Object { - function, ok := fn.(*object.Function) - if !ok { + switch fn := fn.(type) { + + case *object.Function: + extendedEnv := extendFunctionEnv(fn, args) + evaluated := Eval(fn.Body, extendedEnv) + return unwrapReturnValue(evaluated) + + case *object.Builtin: + return fn.Fn(args...) 
+ + default: return newError("not a function: %s", fn.Type()) } - - extendedEnv := extendFunctionEnv(function, args) - evaluated := Eval(function.Body, extendedEnv) - return unwrapReturnValue(evaluated) } func extendFunctionEnv( @@ -323,3 +372,71 @@ func unwrapReturnValue(obj object.Object) object.Object { return obj } + +func evalIndexExpression(left, index object.Object) object.Object { + switch { + case left.Type() == object.ARRAY_OBJ && index.Type() == object.INTEGER_OBJ: + return evalArrayIndexExpression(left, index) + case left.Type() == object.HASH_OBJ: + return evalHashIndexExpression(left, index) + default: + return newError("index operator not supported: %s", left.Type()) + } +} + +func evalArrayIndexExpression(array, index object.Object) object.Object { + arrayObject := array.(*object.Array) + idx := index.(*object.Integer).Value + max := int64(len(arrayObject.Elements) - 1) + + if idx < 0 || idx > max { + return NULL + } + + return arrayObject.Elements[idx] +} + +func evalHashLiteral( + node *ast.HashLiteral, + env *object.Environment, +) object.Object { + pairs := make(map[object.HashKey]object.HashPair) + + for keyNode, valueNode := range node.Pairs { + key := Eval(keyNode, env) + if isError(key) { + return key + } + + hashKey, ok := key.(object.Hashable) + if !ok { + return newError("unusable as hash key: %s", key.Type()) + } + + value := Eval(valueNode, env) + if isError(value) { + return value + } + + hashed := hashKey.HashKey() + pairs[hashed] = object.HashPair{Key: key, Value: value} + } + + return &object.Hash{Pairs: pairs} +} + +func evalHashIndexExpression(hash, index object.Object) object.Object { + hashObject := hash.(*object.Hash) + + key, ok := index.(object.Hashable) + if !ok { + return newError("unusable as hash key: %s", index.Type()) + } + + pair, ok := hashObject.Pairs[key.HashKey()] + if !ok { + return NULL + } + + return pair.Value +} diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index e8c77fa..932b07f 
100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -189,6 +189,10 @@ func TestErrorHandling(t *testing.T) { "5; true + false; 5", "unknown operator: BOOLEAN + BOOLEAN", }, + { + `"Hello" - "World"`, + "unknown operator: STRING - STRING", + }, { "if (10 > 1) { true + false; }", "unknown operator: BOOLEAN + BOOLEAN", @@ -209,6 +213,14 @@ if (10 > 1) { "foobar", "identifier not found: foobar", }, + { + `{"name": "Monkey"}[fn(x) { x }];`, + "unusable as hash key: FUNCTION", + }, + { + `999[1]`, + "index operator not supported: INTEGER", + }, } for _, tt := range tests { @@ -304,6 +316,272 @@ ourFunction(20) + first + second;` testIntegerObject(t, testEval(input), 70) } +func TestClosures(t *testing.T) { + input := ` +let newAdder = fn(x) { + fn(y) { x + y }; +}; + +let addTwo = newAdder(2); +addTwo(2);` + + testIntegerObject(t, testEval(input), 4) +} + +func TestStringLiteral(t *testing.T) { + input := `"Hello World!"` + + evaluated := testEval(input) + str, ok := evaluated.(*object.String) + if !ok { + t.Fatalf("object is not String. got=%T (%+v)", evaluated, evaluated) + } + + if str.Value != "Hello World!" { + t.Errorf("String has wrong value. got=%q", str.Value) + } +} + +func TestStringConcatenation(t *testing.T) { + input := `"Hello" + " " + "World!"` + + evaluated := testEval(input) + str, ok := evaluated.(*object.String) + if !ok { + t.Fatalf("object is not String. got=%T (%+v)", evaluated, evaluated) + } + + if str.Value != "Hello World!" { + t.Errorf("String has wrong value. got=%q", str.Value) + } +} + +func TestBuiltinFunctions(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + {`len("")`, 0}, + {`len("four")`, 4}, + {`len("hello world")`, 11}, + {`len(1)`, "argument to `len` not supported, got INTEGER"}, + {`len("one", "two")`, "wrong number of arguments. 
got=2, want=1"}, + {`len([1, 2, 3])`, 3}, + {`len([])`, 0}, + {`puts("hello", "world!")`, nil}, + {`first([1, 2, 3])`, 1}, + {`first([])`, nil}, + {`first(1)`, "argument to `first` must be ARRAY, got INTEGER"}, + {`last([1, 2, 3])`, 3}, + {`last([])`, nil}, + {`last(1)`, "argument to `last` must be ARRAY, got INTEGER"}, + {`rest([1, 2, 3])`, []int{2, 3}}, + {`rest([])`, nil}, + {`push([], 1)`, []int{1}}, + {`push(1)`, "argument to `push` must be ARRAY, got INTEGER"}, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + + switch expected := tt.expected.(type) { + case int: + testIntegerObject(t, evaluated, int64(expected)) + case nil: + testNullObject(t, evaluated) + case string: + errObj, ok := evaluated.(*object.Error) + if !ok { + t.Errorf("object is not Error. got=%T (%+v)", + evaluated, evaluated) + continue + } + if errObj.Message != expected { + t.Errorf("wrong error message. expected=%q, got=%q", + expected, errObj.Message) + } + case []int: + array, ok := evaluated.(*object.Array) + if !ok { + t.Errorf("obj not Array. got=%T (%+v)", evaluated, evaluated) + continue + } + + if len(array.Elements) != len(expected) { + t.Errorf("wrong num of elements. want=%d, got=%d", + len(expected), len(array.Elements)) + continue + } + + for i, expectedElem := range expected { + testIntegerObject(t, array.Elements[i], int64(expectedElem)) + } + } + } +} + +func TestArrayLiterals(t *testing.T) { + input := "[1, 2 * 2, 3 + 3]" + + evaluated := testEval(input) + result, ok := evaluated.(*object.Array) + if !ok { + t.Fatalf("object is not Array. got=%T (%+v)", evaluated, evaluated) + } + + if len(result.Elements) != 3 { + t.Fatalf("array has wrong num of elements. 
got=%d", + len(result.Elements)) + } + + testIntegerObject(t, result.Elements[0], 1) + testIntegerObject(t, result.Elements[1], 4) + testIntegerObject(t, result.Elements[2], 6) +} + +func TestArrayIndexExpressions(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + { + "[1, 2, 3][0]", + 1, + }, + { + "[1, 2, 3][1]", + 2, + }, + { + "[1, 2, 3][2]", + 3, + }, + { + "let i = 0; [1][i];", + 1, + }, + { + "[1, 2, 3][1 + 1];", + 3, + }, + { + "let myArray = [1, 2, 3]; myArray[2];", + 3, + }, + { + "let myArray = [1, 2, 3]; myArray[0] + myArray[1] + myArray[2];", + 6, + }, + { + "let myArray = [1, 2, 3]; let i = myArray[0]; myArray[i]", + 2, + }, + { + "[1, 2, 3][3]", + nil, + }, + { + "[1, 2, 3][-1]", + nil, + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + integer, ok := tt.expected.(int) + if ok { + testIntegerObject(t, evaluated, int64(integer)) + } else { + testNullObject(t, evaluated) + } + } +} + +func TestHashLiterals(t *testing.T) { + input := `let two = "two"; + { + "one": 10 - 9, + two: 1 + 1, + "thr" + "ee": 6 / 2, + 4: 4, + true: 5, + false: 6 + }` + + evaluated := testEval(input) + result, ok := evaluated.(*object.Hash) + if !ok { + t.Fatalf("Eval didn't return Hash. got=%T (%+v)", evaluated, evaluated) + } + + expected := map[object.HashKey]int64{ + (&object.String{Value: "one"}).HashKey(): 1, + (&object.String{Value: "two"}).HashKey(): 2, + (&object.String{Value: "three"}).HashKey(): 3, + (&object.Integer{Value: 4}).HashKey(): 4, + TRUE.HashKey(): 5, + FALSE.HashKey(): 6, + } + + if len(result.Pairs) != len(expected) { + t.Fatalf("Hash has wrong num of pairs. 
got=%d", len(result.Pairs)) + } + + for expectedKey, expectedValue := range expected { + pair, ok := result.Pairs[expectedKey] + if !ok { + t.Errorf("no pair for given key in Pairs") + } + + testIntegerObject(t, pair.Value, expectedValue) + } +} + +func TestHashIndexExpressions(t *testing.T) { + tests := []struct { + input string + expected interface{} + }{ + { + `{"foo": 5}["foo"]`, + 5, + }, + { + `{"foo": 5}["bar"]`, + nil, + }, + { + `let key = "foo"; {"foo": 5}[key]`, + 5, + }, + { + `{}["foo"]`, + nil, + }, + { + `{5: 5}[5]`, + 5, + }, + { + `{true: 5}[true]`, + 5, + }, + { + `{false: 5}[false]`, + 5, + }, + } + + for _, tt := range tests { + evaluated := testEval(tt.input) + integer, ok := tt.expected.(int) + if ok { + testIntegerObject(t, evaluated, int64(integer)) + } else { + testNullObject(t, evaluated) + } + } +} func testEval(input string) object.Object { l := lexer.New(input) p := parser.New(l) diff --git a/lexer/lexer.go b/lexer/lexer.go index cffd925..8deda73 100644 --- a/lexer/lexer.go +++ b/lexer/lexer.go @@ -51,6 +51,8 @@ func (l *Lexer) NextToken() token.Token { tok = newToken(token.GT, l.ch) case ';': tok = newToken(token.SEMICOLON, l.ch) + case ':': + tok = newToken(token.COLON, l.ch) case ',': tok = newToken(token.COMMA, l.ch) case '{': @@ -61,6 +63,13 @@ func (l *Lexer) NextToken() token.Token { tok = newToken(token.LPAREN, l.ch) case ')': tok = newToken(token.RPAREN, l.ch) + case '"': + tok.Type = token.STRING + tok.Literal = l.readString() + case '[': + tok = newToken(token.LBRACKET, l.ch) + case ']': + tok = newToken(token.RBRACKET, l.ch) case 0: tok.Literal = "" tok.Type = token.EOF @@ -122,6 +131,17 @@ func (l *Lexer) readNumber() string { return l.input[position:l.position] } +func (l *Lexer) readString() string { + position := l.position + 1 + for { + l.readChar() + if l.ch == '"' { + break + } + } + return l.input[position:l.position] +} + func isLetter(ch byte) bool { return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == 
'_' } diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 0a7f248..e4e64a4 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -26,6 +26,10 @@ if (5 < 10) { 10 == 10; 10 != 9; +"foobar" +"foo bar" +[1, 2]; +{"foo": "bar"} ` tests := []struct { @@ -105,6 +109,19 @@ if (5 < 10) { {token.NOT_EQ, "!="}, {token.INT, "9"}, {token.SEMICOLON, ";"}, + {token.STRING, "foobar"}, + {token.STRING, "foo bar"}, + {token.LBRACKET, "["}, + {token.INT, "1"}, + {token.COMMA, ","}, + {token.INT, "2"}, + {token.RBRACKET, "]"}, + {token.SEMICOLON, ";"}, + {token.LBRACE, "{"}, + {token.STRING, "foo"}, + {token.COLON, ":"}, + {token.STRING, "bar"}, + {token.RBRACE, "}"}, {token.EOF, ""}, } diff --git a/object/object.go b/object/object.go index cdde084..2c2a1b0 100644 --- a/object/object.go +++ b/object/object.go @@ -3,10 +3,13 @@ package object import ( "bytes" "fmt" + "hash/fnv" "monkey/ast" "strings" ) +type BuiltinFunction func(args ...Object) Object + type ObjectType string const ( @@ -15,12 +18,26 @@ const ( INTEGER_OBJ = "INTEGER" BOOLEAN_OBJ = "BOOLEAN" + STRING_OBJ = "STRING" RETURN_VALUE_OBJ = "RETURN_VALUE" FUNCTION_OBJ = "FUNCTION" + BUILTIN_OBJ = "BUILTIN" + + ARRAY_OBJ = "ARRAY" + HASH_OBJ = "HASH" ) +type HashKey struct { + Type ObjectType + Value uint64 +} + +type Hashable interface { + HashKey() HashKey +} + type Object interface { Type() ObjectType Inspect() string @@ -32,6 +49,9 @@ type Integer struct { func (i *Integer) Type() ObjectType { return INTEGER_OBJ } func (i *Integer) Inspect() string { return fmt.Sprintf("%d", i.Value) } +func (i *Integer) HashKey() HashKey { + return HashKey{Type: i.Type(), Value: uint64(i.Value)} +} type Boolean struct { Value bool @@ -39,6 +59,17 @@ type Boolean struct { func (b *Boolean) Type() ObjectType { return BOOLEAN_OBJ } func (b *Boolean) Inspect() string { return fmt.Sprintf("%t", b.Value) } +func (b *Boolean) HashKey() HashKey { + var value uint64 + + if b.Value { + value = 1 + } else { + value = 0 + } + + 
return HashKey{Type: b.Type(), Value: value} +} type Null struct{} @@ -83,3 +114,69 @@ func (f *Function) Inspect() string { return out.String() } + +type String struct { + Value string +} + +func (s *String) Type() ObjectType { return STRING_OBJ } +func (s *String) Inspect() string { return s.Value } +func (s *String) HashKey() HashKey { + h := fnv.New64a() + h.Write([]byte(s.Value)) + + return HashKey{Type: s.Type(), Value: h.Sum64()} +} + +type Builtin struct { + Fn BuiltinFunction +} + +func (b *Builtin) Type() ObjectType { return BUILTIN_OBJ } +func (b *Builtin) Inspect() string { return "builtin function" } + +type Array struct { + Elements []Object +} + +func (ao *Array) Type() ObjectType { return ARRAY_OBJ } +func (ao *Array) Inspect() string { + var out bytes.Buffer + + elements := []string{} + for _, e := range ao.Elements { + elements = append(elements, e.Inspect()) + } + + out.WriteString("[") + out.WriteString(strings.Join(elements, ", ")) + out.WriteString("]") + + return out.String() +} + +type HashPair struct { + Key Object + Value Object +} + +type Hash struct { + Pairs map[HashKey]HashPair +} + +func (h *Hash) Type() ObjectType { return HASH_OBJ } +func (h *Hash) Inspect() string { + var out bytes.Buffer + + pairs := []string{} + for _, pair := range h.Pairs { + pairs = append(pairs, fmt.Sprintf("%s: %s", + pair.Key.Inspect(), pair.Value.Inspect())) + } + + out.WriteString("{") + out.WriteString(strings.Join(pairs, ", ")) + out.WriteString("}") + + return out.String() +} diff --git a/object/object_test.go b/object/object_test.go new file mode 100644 index 0000000..63228a2 --- /dev/null +++ b/object/object_test.go @@ -0,0 +1,60 @@ +package object + +import "testing" + +func TestStringHashKey(t *testing.T) { + hello1 := &String{Value: "Hello World"} + hello2 := &String{Value: "Hello World"} + diff1 := &String{Value: "My name is johnny"} + diff2 := &String{Value: "My name is johnny"} + + if hello1.HashKey() != hello2.HashKey() { + t.Errorf("strings 
with same content have different hash keys") + } + + if diff1.HashKey() != diff2.HashKey() { + t.Errorf("strings with same content have different hash keys") + } + + if hello1.HashKey() == diff1.HashKey() { + t.Errorf("strings with different content have same hash keys") + } +} + +func TestBooleanHashKey(t *testing.T) { + true1 := &Boolean{Value: true} + true2 := &Boolean{Value: true} + false1 := &Boolean{Value: false} + false2 := &Boolean{Value: false} + + if true1.HashKey() != true2.HashKey() { + t.Errorf("trues do not have same hash key") + } + + if false1.HashKey() != false2.HashKey() { + t.Errorf("falses do not have same hash key") + } + + if true1.HashKey() == false1.HashKey() { + t.Errorf("true has same hash key as false") + } +} + +func TestIntegerHashKey(t *testing.T) { + one1 := &Integer{Value: 1} + one2 := &Integer{Value: 1} + two1 := &Integer{Value: 2} + two2 := &Integer{Value: 2} + + if one1.HashKey() != one2.HashKey() { + t.Errorf("integers with same content have different hash keys") + } + + if two1.HashKey() != two2.HashKey() { + t.Errorf("integers with same content have different hash keys") + } + + if one1.HashKey() == two1.HashKey() { + t.Errorf("integers with different content have same hash keys") + } +} diff --git a/parser/parser.go b/parser/parser.go index 0c4a9b7..e02595c 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -17,6 +17,7 @@ const ( PRODUCT // * PREFIX // -X or !X CALL // myFunction(X) + INDEX // array[index] ) var precedences = map[token.TokenType]int{ @@ -29,6 +30,7 @@ var precedences = map[token.TokenType]int{ token.SLASH: PRODUCT, token.ASTERISK: PRODUCT, token.LPAREN: CALL, + token.LBRACKET: INDEX, } type ( @@ -56,6 +58,7 @@ func New(l *lexer.Lexer) *Parser { p.prefixParseFns = make(map[token.TokenType]prefixParseFn) p.registerPrefix(token.IDENT, p.parseIdentifier) p.registerPrefix(token.INT, p.parseIntegerLiteral) + p.registerPrefix(token.STRING, p.parseStringLiteral) p.registerPrefix(token.BANG, p.parsePrefixExpression) 
p.registerPrefix(token.MINUS, p.parsePrefixExpression) p.registerPrefix(token.TRUE, p.parseBoolean) @@ -63,6 +66,8 @@ func New(l *lexer.Lexer) *Parser { p.registerPrefix(token.LPAREN, p.parseGroupedExpression) p.registerPrefix(token.IF, p.parseIfExpression) p.registerPrefix(token.FUNCTION, p.parseFunctionLiteral) + p.registerPrefix(token.LBRACKET, p.parseArrayLiteral) + p.registerPrefix(token.LBRACE, p.parseHashLiteral) p.infixParseFns = make(map[token.TokenType]infixParseFn) p.registerInfix(token.PLUS, p.parseInfixExpression) @@ -75,6 +80,7 @@ func New(l *lexer.Lexer) *Parser { p.registerInfix(token.GT, p.parseInfixExpression) p.registerInfix(token.LPAREN, p.parseCallExpression) + p.registerInfix(token.LBRACKET, p.parseIndexExpression) // Read two tokens, so curToken and peekToken are both set p.nextToken() @@ -254,6 +260,10 @@ func (p *Parser) parseIntegerLiteral() ast.Expression { return lit } +func (p *Parser) parseStringLiteral() ast.Expression { + return &ast.StringLiteral{Token: p.curToken, Value: p.curToken.Literal} +} + func (p *Parser) parsePrefixExpression() ast.Expression { expression := &ast.PrefixExpression{ Token: p.curToken, @@ -394,32 +404,82 @@ func (p *Parser) parseFunctionParameters() []*ast.Identifier { func (p *Parser) parseCallExpression(function ast.Expression) ast.Expression { exp := &ast.CallExpression{Token: p.curToken, Function: function} - exp.Arguments = p.parseCallArguments() + exp.Arguments = p.parseExpressionList(token.RPAREN) return exp } -func (p *Parser) parseCallArguments() []ast.Expression { - args := []ast.Expression{} +func (p *Parser) parseExpressionList(end token.TokenType) []ast.Expression { + list := []ast.Expression{} - if p.peekTokenIs(token.RPAREN) { + if p.peekTokenIs(end) { p.nextToken() - return args + return list } p.nextToken() - args = append(args, p.parseExpression(LOWEST)) + list = append(list, p.parseExpression(LOWEST)) for p.peekTokenIs(token.COMMA) { p.nextToken() p.nextToken() - args = append(args, 
p.parseExpression(LOWEST)) + list = append(list, p.parseExpression(LOWEST)) } - if !p.expectPeek(token.RPAREN) { + if !p.expectPeek(end) { + return nil + } + + return list +} + +func (p *Parser) parseArrayLiteral() ast.Expression { + array := &ast.ArrayLiteral{Token: p.curToken} + + array.Elements = p.parseExpressionList(token.RBRACKET) + + return array +} + +func (p *Parser) parseIndexExpression(left ast.Expression) ast.Expression { + exp := &ast.IndexExpression{Token: p.curToken, Left: left} + + p.nextToken() + exp.Index = p.parseExpression(LOWEST) + + if !p.expectPeek(token.RBRACKET) { + return nil + } + + return exp +} + +func (p *Parser) parseHashLiteral() ast.Expression { + hash := &ast.HashLiteral{Token: p.curToken} + hash.Pairs = make(map[ast.Expression]ast.Expression) + + for !p.peekTokenIs(token.RBRACE) { + p.nextToken() + key := p.parseExpression(LOWEST) + + if !p.expectPeek(token.COLON) { + return nil + } + + p.nextToken() + value := p.parseExpression(LOWEST) + + hash.Pairs[key] = value + + if !p.peekTokenIs(token.RBRACE) && !p.expectPeek(token.COMMA) { + return nil + } + } + + if !p.expectPeek(token.RBRACE) { return nil } - return args + return hash } func (p *Parser) registerPrefix(tokenType token.TokenType, fn prefixParseFn) { diff --git a/parser/parser_test.go b/parser/parser_test.go index fa080aa..e187c9f 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -345,6 +345,14 @@ func TestOperatorPrecedenceParsing(t *testing.T) { "add(a + b + c * d / f + g)", "add((((a + b) + ((c * d) / f)) + g))", }, + { + "a * [1, 2, 3, 4][b * c] * d", + "((a * ([1, 2, 3, 4][(b * c)])) * d)", + }, + { + "add(a * b[2], b[1], 2 * [1, 2][1])", + "add((a * (b[2])), (b[1]), (2 * ([1, 2][1])))", + }, } for _, tt := range tests { @@ -674,6 +682,264 @@ func TestCallExpressionParameterParsing(t *testing.T) { } } +func TestStringLiteralExpression(t *testing.T) { + input := `"hello world";` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + 
checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + literal, ok := stmt.Expression.(*ast.StringLiteral) + if !ok { + t.Fatalf("exp not *ast.StringLiteral. got=%T", stmt.Expression) + } + + if literal.Value != "hello world" { + t.Errorf("literal.Value not %q. got=%q", "hello world", literal.Value) + } +} + +func TestParsingEmptyArrayLiterals(t *testing.T) { + input := "[]" + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + array, ok := stmt.Expression.(*ast.ArrayLiteral) + if !ok { + t.Fatalf("exp not ast.ArrayLiteral. got=%T", stmt.Expression) + } + + if len(array.Elements) != 0 { + t.Errorf("len(array.Elements) not 0. got=%d", len(array.Elements)) + } +} + +func TestParsingArrayLiterals(t *testing.T) { + input := "[1, 2 * 2, 3 + 3]" + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + array, ok := stmt.Expression.(*ast.ArrayLiteral) + if !ok { + t.Fatalf("exp not ast.ArrayLiteral. got=%T", stmt.Expression) + } + + if len(array.Elements) != 3 { + t.Fatalf("len(array.Elements) not 3. got=%d", len(array.Elements)) + } + + testIntegerLiteral(t, array.Elements[0], 1) + testInfixExpression(t, array.Elements[1], 2, "*", 2) + testInfixExpression(t, array.Elements[2], 3, "+", 3) +} + +func TestParsingIndexExpressions(t *testing.T) { + input := "myArray[1 + 1]" + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt, ok := program.Statements[0].(*ast.ExpressionStatement) + indexExp, ok := stmt.Expression.(*ast.IndexExpression) + if !ok { + t.Fatalf("exp not *ast.IndexExpression. 
got=%T", stmt.Expression) + } + + if !testIdentifier(t, indexExp.Left, "myArray") { + return + } + + if !testInfixExpression(t, indexExp.Index, 1, "+", 1) { + return + } +} + +func TestParsingEmptyHashLiteral(t *testing.T) { + input := "{}" + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + hash, ok := stmt.Expression.(*ast.HashLiteral) + if !ok { + t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression) + } + + if len(hash.Pairs) != 0 { + t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs)) + } +} + +func TestParsingHashLiteralsStringKeys(t *testing.T) { + input := `{"one": 1, "two": 2, "three": 3}` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + hash, ok := stmt.Expression.(*ast.HashLiteral) + if !ok { + t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression) + } + + expected := map[string]int64{ + "one": 1, + "two": 2, + "three": 3, + } + + if len(hash.Pairs) != len(expected) { + t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs)) + } + + for key, value := range hash.Pairs { + literal, ok := key.(*ast.StringLiteral) + if !ok { + t.Errorf("key is not ast.StringLiteral. got=%T", key) + continue + } + + expectedValue := expected[literal.String()] + testIntegerLiteral(t, value, expectedValue) + } +} + +func TestParsingHashLiteralsBooleanKeys(t *testing.T) { + input := `{true: 1, false: 2}` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + hash, ok := stmt.Expression.(*ast.HashLiteral) + if !ok { + t.Fatalf("exp is not ast.HashLiteral. 
got=%T", stmt.Expression) + } + + expected := map[string]int64{ + "true": 1, + "false": 2, + } + + if len(hash.Pairs) != len(expected) { + t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs)) + } + + for key, value := range hash.Pairs { + boolean, ok := key.(*ast.Boolean) + if !ok { + t.Errorf("key is not ast.BooleanLiteral. got=%T", key) + continue + } + + expectedValue := expected[boolean.String()] + testIntegerLiteral(t, value, expectedValue) + } +} + +func TestParsingHashLiteralsIntegerKeys(t *testing.T) { + input := `{1: 1, 2: 2, 3: 3}` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + hash, ok := stmt.Expression.(*ast.HashLiteral) + if !ok { + t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression) + } + + expected := map[string]int64{ + "1": 1, + "2": 2, + "3": 3, + } + + if len(hash.Pairs) != len(expected) { + t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs)) + } + + for key, value := range hash.Pairs { + integer, ok := key.(*ast.IntegerLiteral) + if !ok { + t.Errorf("key is not ast.IntegerLiteral. got=%T", key) + continue + } + + expectedValue := expected[integer.String()] + + testIntegerLiteral(t, value, expectedValue) + } +} + +func TestParsingHashLiteralsWithExpressions(t *testing.T) { + input := `{"one": 0 + 1, "two": 10 - 8, "three": 15 / 5}` + + l := lexer.New(input) + p := New(l) + program := p.ParseProgram() + checkParserErrors(t, p) + + stmt := program.Statements[0].(*ast.ExpressionStatement) + hash, ok := stmt.Expression.(*ast.HashLiteral) + if !ok { + t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression) + } + + if len(hash.Pairs) != 3 { + t.Errorf("hash.Pairs has wrong length. 
got=%d", len(hash.Pairs)) + } + + tests := map[string]func(ast.Expression){ + "one": func(e ast.Expression) { + testInfixExpression(t, e, 0, "+", 1) + }, + "two": func(e ast.Expression) { + testInfixExpression(t, e, 10, "-", 8) + }, + "three": func(e ast.Expression) { + testInfixExpression(t, e, 15, "/", 5) + }, + } + + for key, value := range hash.Pairs { + literal, ok := key.(*ast.StringLiteral) + if !ok { + t.Errorf("key is not ast.StringLiteral. got=%T", key) + continue + } + + testFunc, ok := tests[literal.String()] + if !ok { + t.Errorf("No test function for key %q found", literal.String()) + continue + } + + testFunc(value) + } +} + func testLetStatement(t *testing.T, s ast.Statement, name string) bool { if s.TokenLiteral() != "let" { t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral()) diff --git a/token/token.go b/token/token.go index 12158fa..3d2d2f7 100644 --- a/token/token.go +++ b/token/token.go @@ -7,8 +7,9 @@ const ( EOF = "EOF" // Identifiers + literals - IDENT = "IDENT" // add, foobar, x, y, ... - INT = "INT" // 1343456 + IDENT = "IDENT" // add, foobar, x, y, ... + INT = "INT" // 1343456 + STRING = "STRING" // "foobar" // Operators ASSIGN = "=" @@ -27,11 +28,14 @@ const ( // Delimiters COMMA = "," SEMICOLON = ";" + COLON = ":" - LPAREN = "(" - RPAREN = ")" - LBRACE = "{" - RBRACE = "}" + LPAREN = "(" + RPAREN = ")" + LBRACE = "{" + RBRACE = "}" + LBRACKET = "[" + RBRACKET = "]" // Keywords FUNCTION = "FUNCTION" -- cgit v1.2.3