368 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			368 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (C) MongoDB, Inc. 2017-present.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License"); you may
 | |
| // not use this file except in compliance with the License. You may obtain
 | |
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| package bsoncodec
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 
 | |
| 	"go.mongodb.org/mongo-driver/bson/bsonrw"
 | |
| 	"go.mongodb.org/mongo-driver/bson/bsontype"
 | |
| )
 | |
| 
 | |
| var defaultStructCodec = &StructCodec{
 | |
| 	cache:  make(map[reflect.Type]*structDescription),
 | |
| 	parser: DefaultStructTagParser,
 | |
| }
 | |
| 
 | |
// Zeroer is implemented by custom struct types that want to report their own
// zero state. A struct type that does not implement Zeroer, or whose IsZero
// method returns false, is considered to be not zero.
type Zeroer interface {
	// IsZero reports whether the receiver should be treated as zero.
	IsZero() bool
}
 | |
| 
 | |
| // StructCodec is the Codec used for struct values.
 | |
| type StructCodec struct {
 | |
| 	cache  map[reflect.Type]*structDescription
 | |
| 	l      sync.RWMutex
 | |
| 	parser StructTagParser
 | |
| }
 | |
| 
 | |
| var _ ValueEncoder = &StructCodec{}
 | |
| var _ ValueDecoder = &StructCodec{}
 | |
| 
 | |
| // NewStructCodec returns a StructCodec that uses p for struct tag parsing.
 | |
| func NewStructCodec(p StructTagParser) (*StructCodec, error) {
 | |
| 	if p == nil {
 | |
| 		return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
 | |
| 	}
 | |
| 
 | |
| 	return &StructCodec{
 | |
| 		cache:  make(map[reflect.Type]*structDescription),
 | |
| 		parser: p,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // EncodeValue handles encoding generic struct types.
 | |
| func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
 | |
| 	if !val.IsValid() || val.Kind() != reflect.Struct {
 | |
| 		return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
 | |
| 	}
 | |
| 
 | |
| 	sd, err := sc.describeStruct(r.Registry, val.Type())
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	dw, err := vw.WriteDocument()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	var rv reflect.Value
 | |
| 	for _, desc := range sd.fl {
 | |
| 		if desc.inline == nil {
 | |
| 			rv = val.Field(desc.idx)
 | |
| 		} else {
 | |
| 			rv = val.FieldByIndex(desc.inline)
 | |
| 		}
 | |
| 
 | |
| 		if desc.encoder == nil {
 | |
| 			return ErrNoEncoder{Type: rv.Type()}
 | |
| 		}
 | |
| 
 | |
| 		encoder := desc.encoder
 | |
| 
 | |
| 		iszero := sc.isZero
 | |
| 		if iz, ok := encoder.(CodecZeroer); ok {
 | |
| 			iszero = iz.IsTypeZero
 | |
| 		}
 | |
| 
 | |
| 		if desc.omitEmpty && iszero(rv.Interface()) {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		vw2, err := dw.WriteDocumentElement(desc.name)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
 | |
| 		err = encoder.EncodeValue(ectx, vw2, rv)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if sd.inlineMap >= 0 {
 | |
| 		rv := val.Field(sd.inlineMap)
 | |
| 		collisionFn := func(key string) bool {
 | |
| 			_, exists := sd.fm[key]
 | |
| 			return exists
 | |
| 		}
 | |
| 
 | |
| 		return defaultValueEncoders.mapEncodeValue(r, dw, rv, collisionFn)
 | |
| 	}
 | |
| 
 | |
| 	return dw.WriteDocumentEnd()
 | |
| }
 | |
| 
 | |
| // DecodeValue implements the Codec interface.
 | |
| // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
 | |
| // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
 | |
| func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
 | |
| 	if !val.CanSet() || val.Kind() != reflect.Struct {
 | |
| 		return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
 | |
| 	}
 | |
| 
 | |
| 	switch vr.Type() {
 | |
| 	case bsontype.Type(0), bsontype.EmbeddedDocument:
 | |
| 	default:
 | |
| 		return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type())
 | |
| 	}
 | |
| 
 | |
| 	sd, err := sc.describeStruct(r.Registry, val.Type())
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	var decoder ValueDecoder
 | |
| 	var inlineMap reflect.Value
 | |
| 	if sd.inlineMap >= 0 {
 | |
| 		inlineMap = val.Field(sd.inlineMap)
 | |
| 		if inlineMap.IsNil() {
 | |
| 			inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
 | |
| 		}
 | |
| 		decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	dr, err := vr.ReadDocument()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	for {
 | |
| 		name, vr, err := dr.ReadElement()
 | |
| 		if err == bsonrw.ErrEOD {
 | |
| 			break
 | |
| 		}
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		fd, exists := sd.fm[name]
 | |
| 		if !exists {
 | |
| 			// if the original name isn't found in the struct description, try again with the name in lowercase
 | |
| 			// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
 | |
| 			// names
 | |
| 			fd, exists = sd.fm[strings.ToLower(name)]
 | |
| 		}
 | |
| 
 | |
| 		if !exists {
 | |
| 			if sd.inlineMap < 0 {
 | |
| 				// The encoding/json package requires a flag to return on error for non-existent fields.
 | |
| 				// This functionality seems appropriate for the struct codec.
 | |
| 				err = vr.Skip()
 | |
| 				if err != nil {
 | |
| 					return err
 | |
| 				}
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			elem := reflect.New(inlineMap.Type().Elem()).Elem()
 | |
| 			err = decoder.DecodeValue(r, vr, elem)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		var field reflect.Value
 | |
| 		if fd.inline == nil {
 | |
| 			field = val.Field(fd.idx)
 | |
| 		} else {
 | |
| 			field = val.FieldByIndex(fd.inline)
 | |
| 		}
 | |
| 
 | |
| 		if !field.CanSet() { // Being settable is a super set of being addressable.
 | |
| 			return fmt.Errorf("cannot decode element '%s' into field %v; it is not settable", name, field)
 | |
| 		}
 | |
| 		if field.Kind() == reflect.Ptr && field.IsNil() {
 | |
| 			field.Set(reflect.New(field.Type().Elem()))
 | |
| 		}
 | |
| 		field = field.Addr()
 | |
| 
 | |
| 		dctx := DecodeContext{Registry: r.Registry, Truncate: fd.truncate || r.Truncate}
 | |
| 		if fd.decoder == nil {
 | |
| 			return ErrNoDecoder{Type: field.Elem().Type()}
 | |
| 		}
 | |
| 
 | |
| 		if decoder, ok := fd.decoder.(ValueDecoder); ok {
 | |
| 			err = decoder.DecodeValue(dctx, vr, field.Elem())
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			continue
 | |
| 		}
 | |
| 		err = fd.decoder.DecodeValue(dctx, vr, field)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (sc *StructCodec) isZero(i interface{}) bool {
 | |
| 	v := reflect.ValueOf(i)
 | |
| 
 | |
| 	// check the value validity
 | |
| 	if !v.IsValid() {
 | |
| 		return true
 | |
| 	}
 | |
| 
 | |
| 	if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
 | |
| 		return z.IsZero()
 | |
| 	}
 | |
| 
 | |
| 	switch v.Kind() {
 | |
| 	case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
 | |
| 		return v.Len() == 0
 | |
| 	case reflect.Bool:
 | |
| 		return !v.Bool()
 | |
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 | |
| 		return v.Int() == 0
 | |
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
 | |
| 		return v.Uint() == 0
 | |
| 	case reflect.Float32, reflect.Float64:
 | |
| 		return v.Float() == 0
 | |
| 	case reflect.Interface, reflect.Ptr:
 | |
| 		return v.IsNil()
 | |
| 	}
 | |
| 
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| type structDescription struct {
 | |
| 	fm        map[string]fieldDescription
 | |
| 	fl        []fieldDescription
 | |
| 	inlineMap int
 | |
| }
 | |
| 
 | |
| type fieldDescription struct {
 | |
| 	name      string
 | |
| 	idx       int
 | |
| 	omitEmpty bool
 | |
| 	minSize   bool
 | |
| 	truncate  bool
 | |
| 	inline    []int
 | |
| 	encoder   ValueEncoder
 | |
| 	decoder   ValueDecoder
 | |
| }
 | |
| 
 | |
| func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
 | |
| 	// We need to analyze the struct, including getting the tags, collecting
 | |
| 	// information about inlining, and create a map of the field name to the field.
 | |
| 	sc.l.RLock()
 | |
| 	ds, exists := sc.cache[t]
 | |
| 	sc.l.RUnlock()
 | |
| 	if exists {
 | |
| 		return ds, nil
 | |
| 	}
 | |
| 
 | |
| 	numFields := t.NumField()
 | |
| 	sd := &structDescription{
 | |
| 		fm:        make(map[string]fieldDescription, numFields),
 | |
| 		fl:        make([]fieldDescription, 0, numFields),
 | |
| 		inlineMap: -1,
 | |
| 	}
 | |
| 
 | |
| 	for i := 0; i < numFields; i++ {
 | |
| 		sf := t.Field(i)
 | |
| 		if sf.PkgPath != "" {
 | |
| 			// unexported, ignore
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		encoder, err := r.LookupEncoder(sf.Type)
 | |
| 		if err != nil {
 | |
| 			encoder = nil
 | |
| 		}
 | |
| 		decoder, err := r.LookupDecoder(sf.Type)
 | |
| 		if err != nil {
 | |
| 			decoder = nil
 | |
| 		}
 | |
| 
 | |
| 		description := fieldDescription{idx: i, encoder: encoder, decoder: decoder}
 | |
| 
 | |
| 		stags, err := sc.parser.ParseStructTags(sf)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		if stags.Skip {
 | |
| 			continue
 | |
| 		}
 | |
| 		description.name = stags.Name
 | |
| 		description.omitEmpty = stags.OmitEmpty
 | |
| 		description.minSize = stags.MinSize
 | |
| 		description.truncate = stags.Truncate
 | |
| 
 | |
| 		if stags.Inline {
 | |
| 			switch sf.Type.Kind() {
 | |
| 			case reflect.Map:
 | |
| 				if sd.inlineMap >= 0 {
 | |
| 					return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
 | |
| 				}
 | |
| 				if sf.Type.Key() != tString {
 | |
| 					return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
 | |
| 				}
 | |
| 				sd.inlineMap = description.idx
 | |
| 			case reflect.Struct:
 | |
| 				inlinesf, err := sc.describeStruct(r, sf.Type)
 | |
| 				if err != nil {
 | |
| 					return nil, err
 | |
| 				}
 | |
| 				for _, fd := range inlinesf.fl {
 | |
| 					if _, exists := sd.fm[fd.name]; exists {
 | |
| 						return nil, fmt.Errorf("(struct %s) duplicated key %s", t.String(), fd.name)
 | |
| 					}
 | |
| 					if fd.inline == nil {
 | |
| 						fd.inline = []int{i, fd.idx}
 | |
| 					} else {
 | |
| 						fd.inline = append([]int{i}, fd.inline...)
 | |
| 					}
 | |
| 					sd.fm[fd.name] = fd
 | |
| 					sd.fl = append(sd.fl, fd)
 | |
| 				}
 | |
| 			default:
 | |
| 				return nil, fmt.Errorf("(struct %s) inline fields must be either a struct or a map", t.String())
 | |
| 			}
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		if _, exists := sd.fm[description.name]; exists {
 | |
| 			return nil, fmt.Errorf("struct %s) duplicated key %s", t.String(), description.name)
 | |
| 		}
 | |
| 
 | |
| 		sd.fm[description.name] = description
 | |
| 		sd.fl = append(sd.fl, description)
 | |
| 	}
 | |
| 
 | |
| 	sc.l.Lock()
 | |
| 	sc.cache[t] = sd
 | |
| 	sc.l.Unlock()
 | |
| 
 | |
| 	return sd, nil
 | |
| }
 |