Skip to content

Commit ff5ad99

Browse files
committed
Add support for url, json number, and decimal
GODRIVER-363 GODRIVER-343 Change-Id: I3a7e4198beb878b7f38f0a296b3be7fab604148f
1 parent 5c4209e commit ff5ad99

File tree

4 files changed

+379
-55
lines changed

4 files changed

+379
-55
lines changed

bson/decode.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"fmt"
1111
"io"
1212
"math"
13+
"net/url"
1314
"reflect"
15+
"strconv"
1416
"strings"
1517
"time"
1618

@@ -355,16 +357,26 @@ func (d *Decoder) getReflectValue(v *Value, containerType reflect.Type, outer re
355357

356358
case tFloat64, tEmpty:
357359
val = reflect.ValueOf(f)
360+
case tJSONNumber:
361+
val = reflect.ValueOf(strconv.FormatFloat(f, 'f', -1, 64)).Convert(tJSONNumber)
358362
default:
359363
return val, nil
360364
}
361365

362366
case 0x2:
363-
if containerType != tString && containerType != tEmpty {
367+
str := v.StringValue()
368+
switch containerType {
369+
case tString, tEmpty:
370+
val = reflect.ValueOf(str)
371+
case tURL:
372+
u, err := url.Parse(str)
373+
if err != nil {
374+
return val, err
375+
}
376+
val = reflect.ValueOf(u).Elem()
377+
default:
364378
return val, nil
365379
}
366-
367-
val = reflect.ValueOf(v.StringValue())
368380
case 0x4:
369381
if containerType == tEmpty {
370382
d := NewDecoder(bytes.NewBuffer(v.ReaderArray()))
@@ -547,6 +559,8 @@ func (d *Decoder) getReflectValue(v *Value, containerType reflect.Type, outer re
547559

548560
case tEmpty, tInt32, tInt64, tInt, tFloat32, tFloat64:
549561
val = reflect.ValueOf(i).Convert(containerType)
562+
case tJSONNumber:
563+
val = reflect.ValueOf(strconv.FormatInt(int64(i), 10)).Convert(tJSONNumber)
550564
default:
551565
return val, nil
552566
}
@@ -609,6 +623,8 @@ func (d *Decoder) getReflectValue(v *Value, containerType reflect.Type, outer re
609623
val = reflect.ValueOf(float32(i))
610624
case tFloat64:
611625
val = reflect.ValueOf(float64(i))
626+
case tJSONNumber:
627+
val = reflect.ValueOf(strconv.FormatInt(i, 10)).Convert(tJSONNumber)
612628
}
613629

614630
case 0x13:

bson/decode_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ package bson
88

99
import (
1010
"bytes"
11+
"encoding/json"
12+
"net/url"
1113
"reflect"
1214
"testing"
1315

1416
"github.com/google/go-cmp/cmp"
17+
"github.com/mongodb/mongo-go-driver/bson/decimal"
1518
"github.com/stretchr/testify/require"
1619
)
1720

@@ -2007,6 +2010,47 @@ func TestDecoder(t *testing.T) {
20072010
return
20082011
}
20092012

2013+
require.True(t, reflect.DeepEqual(tc.expected, tc.actual))
2014+
})
2015+
}
2016+
})
2017+
t.Run("decimal128", func(t *testing.T) {
2018+
decimal128, err := decimal.ParseDecimal128("1.5e10")
2019+
if err != nil {
2020+
t.Errorf("Error parsing decimal128: %v", err)
2021+
t.FailNow()
2022+
}
2023+
testCases := []struct {
2024+
name string
2025+
reader []byte
2026+
expected interface{}
2027+
actual interface{}
2028+
err error
2029+
}{
2030+
{
2031+
"decimal128",
2032+
docToBytes(NewDocument(EC.Decimal128("a", decimal128))),
2033+
&struct {
2034+
A decimal.Decimal128
2035+
}{
2036+
A: decimal128,
2037+
},
2038+
&struct {
2039+
A decimal.Decimal128
2040+
}{},
2041+
nil,
2042+
},
2043+
}
2044+
for _, tc := range testCases {
2045+
t.Run(tc.name, func(t *testing.T) {
2046+
d := NewDecoder(bytes.NewBuffer(tc.reader))
2047+
2048+
err := d.Decode(tc.actual)
2049+
requireErrEqual(t, tc.err, err)
2050+
if err != nil {
2051+
return
2052+
}
2053+
20102054
require.True(t, reflect.DeepEqual(tc.expected, tc.actual))
20112055
})
20122056
}
@@ -2382,6 +2426,63 @@ func TestDecoder(t *testing.T) {
23822426
})
23832427
}
23842428
})
2429+
t.Run("pluggable types", func(t *testing.T) {
2430+
murl, err := url.Parse("https://mongodb.com/random-url?hello=world")
2431+
if err != nil {
2432+
t.Errorf("Error parsing URL: %v", err)
2433+
t.FailNow()
2434+
}
2435+
testCases := []struct {
2436+
name string
2437+
reader []byte
2438+
expected interface{}
2439+
actual interface{}
2440+
err error
2441+
}{
2442+
{
2443+
"*url.URL",
2444+
docToBytes(NewDocument(EC.String("a", murl.String()))),
2445+
&struct {
2446+
A *url.URL
2447+
}{
2448+
A: murl,
2449+
},
2450+
&struct {
2451+
A *url.URL
2452+
}{},
2453+
nil,
2454+
},
2455+
{
2456+
"json.Number",
2457+
docToBytes(NewDocument(EC.Int64("a", 5), EC.Double("b", 10.10))),
2458+
&struct {
2459+
A json.Number
2460+
B json.Number
2461+
}{
2462+
A: json.Number("5"),
2463+
B: json.Number("10.1"),
2464+
},
2465+
&struct {
2466+
A json.Number
2467+
B json.Number
2468+
}{},
2469+
nil,
2470+
},
2471+
}
2472+
for _, tc := range testCases {
2473+
t.Run(tc.name, func(t *testing.T) {
2474+
d := NewDecoder(bytes.NewBuffer(tc.reader))
2475+
2476+
err := d.Decode(tc.actual)
2477+
requireErrEqual(t, tc.err, err)
2478+
if err != nil {
2479+
return
2480+
}
2481+
2482+
require.True(t, reflect.DeepEqual(tc.expected, tc.actual))
2483+
})
2484+
}
2485+
})
23852486
}
23862487

23872488
func elementSliceEqual(t *testing.T, e1 []*Element, e2 []*Element) {

bson/encode.go

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@
77
package bson
88

99
import (
10+
"encoding/json"
1011
"errors"
1112
"fmt"
1213
"io"
1314
"math"
15+
"net/url"
1416
"reflect"
1517
"strconv"
1618
"strings"
1719
"time"
1820

21+
"github.com/mongodb/mongo-go-driver/bson/decimal"
1922
"github.com/mongodb/mongo-go-driver/bson/objectid"
2023
)
2124

@@ -25,6 +28,8 @@ var ErrEncoderNilWriter = errors.New("encoder.Encode called on Encoder with nil
2528
var tByteSlice = reflect.TypeOf(([]byte)(nil))
2629
var tByte = reflect.TypeOf(byte(0x00))
2730
var tElement = reflect.TypeOf((*Element)(nil))
31+
var tURL = reflect.TypeOf(url.URL{})
32+
var tJSONNumber = reflect.TypeOf(json.Number(""))
2833

2934
// Marshaler describes a type that can marshal a BSON representation of itself into bytes.
3035
type Marshaler interface {
@@ -278,6 +283,7 @@ func (e *encoder) encodeMap(val reflect.Value) ([]*Element, error) {
278283
mapkeys := val.MapKeys()
279284
elems := make([]*Element, 0, val.Len())
280285
for _, rkey := range mapkeys {
286+
orig := rkey
281287
rkey = e.underlyingVal(rkey)
282288

283289
var key string
@@ -297,9 +303,15 @@ func (e *encoder) encodeMap(val reflect.Value) ([]*Element, error) {
297303
case reflect.String:
298304
key = rkey.String()
299305
default:
300-
if rkey.Type() == tOID {
306+
switch rkey.Type() {
307+
case tOID:
301308
key = fmt.Sprintf("%s", rkey.Interface())
302-
} else {
309+
case tURL:
310+
rkey = orig
311+
key = fmt.Sprintf("%s", rkey.Interface())
312+
case tDecimal:
313+
key = fmt.Sprintf("%s", rkey.Interface())
314+
default:
303315
return nil, fmt.Errorf("Unsupported map key type %s", rkey.Kind())
304316
}
305317
}
@@ -316,6 +328,24 @@ func (e *encoder) encodeMap(val reflect.Value) ([]*Element, error) {
316328
case Reader:
317329
elems = append(elems, EC.SubDocumentFromReader(key, t))
318330
continue
331+
case json.Number:
332+
// We try to do an int first
333+
if i64, err := t.Int64(); err == nil {
334+
elems = append(elems, EC.Int64(key, i64))
335+
continue
336+
}
337+
f64, err := t.Float64()
338+
if err != nil {
339+
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
340+
}
341+
elems = append(elems, EC.Double(key, f64))
342+
continue
343+
case *url.URL:
344+
elems = append(elems, EC.String(key, t.String()))
345+
continue
346+
case decimal.Decimal128:
347+
elems = append(elems, EC.Decimal128(key, t))
348+
continue
319349
}
320350
rval = e.underlyingVal(rval)
321351

@@ -343,6 +373,24 @@ func (e *encoder) encodeSlice(val reflect.Value) ([]*Element, error) {
343373
case Reader:
344374
elems = append(elems, EC.SubDocumentFromReader(key, t))
345375
continue
376+
case json.Number:
377+
// We try to do an int first
378+
if i64, err := t.Int64(); err == nil {
379+
elems = append(elems, EC.Int64(key, i64))
380+
continue
381+
}
382+
f64, err := t.Float64()
383+
if err != nil {
384+
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
385+
}
386+
elems = append(elems, EC.Double(key, f64))
387+
continue
388+
case *url.URL:
389+
elems = append(elems, EC.String(key, t.String()))
390+
continue
391+
case decimal.Decimal128:
392+
elems = append(elems, EC.Decimal128(key, t))
393+
continue
346394
}
347395
sval = e.underlyingVal(sval)
348396
elem, err := e.elemFromValue(key, sval, false)
@@ -371,6 +419,24 @@ func (e *encoder) encodeSliceAsArray(rval reflect.Value, minsize bool) ([]*Value
371419
case Reader:
372420
vals = append(vals, VC.DocumentFromReader(t))
373421
continue
422+
case json.Number:
423+
// We try to do an int first
424+
if i64, err := t.Int64(); err == nil {
425+
vals = append(vals, VC.Int64(i64))
426+
continue
427+
}
428+
f64, err := t.Float64()
429+
if err != nil {
430+
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
431+
}
432+
vals = append(vals, VC.Double(f64))
433+
continue
434+
case *url.URL:
435+
vals = append(vals, VC.String(t.String()))
436+
continue
437+
case decimal.Decimal128:
438+
vals = append(vals, VC.Decimal128(t))
439+
continue
374440
}
375441

376442
sval = e.underlyingVal(sval)
@@ -429,6 +495,24 @@ func (e *encoder) encodeStruct(val reflect.Value) ([]*Element, error) {
429495
case Reader:
430496
elems = append(elems, EC.SubDocumentFromReader(key, t))
431497
continue
498+
case json.Number:
499+
// We try to do an int first
500+
if i64, err := t.Int64(); err == nil {
501+
elems = append(elems, EC.Int64(key, i64))
502+
continue
503+
}
504+
f64, err := t.Float64()
505+
if err != nil {
506+
return nil, fmt.Errorf("Invalid json.Number used as map value: %s", err)
507+
}
508+
elems = append(elems, EC.Double(key, f64))
509+
continue
510+
case *url.URL:
511+
elems = append(elems, EC.String(key, t.String()))
512+
continue
513+
case decimal.Decimal128:
514+
elems = append(elems, EC.Decimal128(key, t))
515+
continue
432516
}
433517
field = e.underlyingVal(field)
434518

0 commit comments

Comments
 (0)