Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 71 additions & 56 deletions src/encoding/xml/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,12 +416,6 @@ func (p *printer) popPrefix() {
}
}

var (
marshalerType = reflect.TypeFor[Marshaler]()
marshalerAttrType = reflect.TypeFor[MarshalerAttr]()
textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]()
)

// marshalValue writes one or more XML elements representing val.
// If val was obtained from a struct field, finfo must have its details.
func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplate *StartElement) error {
Expand Down Expand Up @@ -450,24 +444,32 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
typ := val.Type()

// Check for marshaler.
if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
if val.CanInterface() {
if marshaler, ok := reflect.TypeAssert[Marshaler](val); ok {
return p.marshalInterface(marshaler, defaultStart(typ, finfo, startTemplate))
}
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
if pv.CanInterface() {
if marshaler, ok := reflect.TypeAssert[Marshaler](pv); ok {
return p.marshalInterface(marshaler, defaultStart(pv.Type(), finfo, startTemplate))
}
}
}

// Check for text marshaler.
if val.CanInterface() && typ.Implements(textMarshalerType) {
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
if val.CanInterface() {
if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](val); ok {
return p.marshalTextInterface(textMarshaler, defaultStart(typ, finfo, startTemplate))
}
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
if pv.CanInterface() {
if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok {
return p.marshalTextInterface(textMarshaler, defaultStart(pv.Type(), finfo, startTemplate))
}
}
}

Expand Down Expand Up @@ -503,7 +505,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
} else {
fv := xmlname.value(val, dontInitNilPointers)
if v, ok := fv.Interface().(Name); ok && v.Local != "" {
if v, ok := reflect.TypeAssert[Name](fv); ok && v.Local != "" {
start.Name = v
}
}
Expand Down Expand Up @@ -580,21 +582,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat

// marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) error {
if val.CanInterface() && val.Type().Implements(marshalerAttrType) {
attr, err := val.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
return nil
}

if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
if val.CanInterface() {
if marshaler, ok := reflect.TypeAssert[MarshalerAttr](val); ok {
attr, err := marshaler.MarshalXMLAttr(name)
if err != nil {
return err
}
Expand All @@ -605,19 +595,25 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
}
}

if val.CanInterface() && val.Type().Implements(textMarshalerType) {
text, err := val.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() {
if marshaler, ok := reflect.TypeAssert[MarshalerAttr](pv); ok {
attr, err := marshaler.MarshalXMLAttr(name)
if err != nil {
return err
}
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
return nil
}
}
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
}

if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if val.CanInterface() {
if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](val); ok {
text, err := textMarshaler.MarshalText()
if err != nil {
return err
}
Expand All @@ -626,6 +622,20 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
}
}

if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() {
if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok {
text, err := textMarshaler.MarshalText()
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
}
}
}

// Dereference or skip nil pointer, interface values.
switch val.Kind() {
case reflect.Pointer, reflect.Interface:
Expand All @@ -647,7 +657,8 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
}

if val.Type() == attrType {
start.Attr = append(start.Attr, val.Interface().(Attr))
attr, _ := reflect.TypeAssert[Attr](val)
start.Attr = append(start.Attr, attr)
return nil
}

Expand Down Expand Up @@ -854,20 +865,9 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
if err := s.trim(finfo.parents); err != nil {
return err
}
if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
if err != nil {
return err
}
if err := emit(p, data); err != nil {
return err
}
continue
}
if vf.CanAddr() {
pv := vf.Addr()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if vf.CanInterface() {
if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](vf); ok {
data, err := textMarshaler.MarshalText()
if err != nil {
return err
}
Expand All @@ -877,6 +877,21 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
continue
}
}
if vf.CanAddr() {
pv := vf.Addr()
if pv.CanInterface() {
if textMarshaler, ok := reflect.TypeAssert[encoding.TextMarshaler](pv); ok {
data, err := textMarshaler.MarshalText()
if err != nil {
return err
}
if err := emit(p, data); err != nil {
return err
}
continue
}
}
}

var scratch [64]byte
vf = indirect(vf)
Expand All @@ -902,7 +917,7 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
return err
}
case reflect.Slice:
if elem, ok := vf.Interface().([]byte); ok {
if elem, ok := reflect.TypeAssert[[]byte](vf); ok {
if err := emit(p, elem); err != nil {
return err
}
Expand Down
77 changes: 46 additions & 31 deletions src/encoding/xml/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,28 +255,36 @@ func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
}
val = val.Elem()
}
if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) {
if val.CanInterface() {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
if unmarshaler, ok := reflect.TypeAssert[UnmarshalerAttr](val); ok {
return unmarshaler.UnmarshalXMLAttr(attr)
}
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) {
return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
if pv.CanInterface() {
if unmarshaler, ok := reflect.TypeAssert[UnmarshalerAttr](pv); ok {
return unmarshaler.UnmarshalXMLAttr(attr)
}
}
}

// Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
if val.CanInterface() {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](val); ok {
return textUnmarshaler.UnmarshalText([]byte(attr.Value))
}
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
if pv.CanInterface() {
if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok {
return textUnmarshaler.UnmarshalText([]byte(attr.Value))
}
}
}

Expand All @@ -303,12 +311,7 @@ func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
return copyValue(val, []byte(attr.Value))
}

var (
attrType = reflect.TypeFor[Attr]()
unmarshalerType = reflect.TypeFor[Unmarshaler]()
unmarshalerAttrType = reflect.TypeFor[UnmarshalerAttr]()
textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
)
var attrType = reflect.TypeFor[Attr]()

const (
maxUnmarshalDepth = 10000
Expand Down Expand Up @@ -352,27 +355,35 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e
val = val.Elem()
}

if val.CanInterface() && val.Type().Implements(unmarshalerType) {
if val.CanInterface() {
// This is an unmarshaler with a non-pointer receiver,
// so it's likely to be incorrect, but we do what we're told.
return d.unmarshalInterface(val.Interface().(Unmarshaler), start)
if unmarshaler, ok := reflect.TypeAssert[Unmarshaler](val); ok {
return d.unmarshalInterface(unmarshaler, start)
}
}

if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(unmarshalerType) {
return d.unmarshalInterface(pv.Interface().(Unmarshaler), start)
if pv.CanInterface() {
if unmarshaler, ok := reflect.TypeAssert[Unmarshaler](pv); ok {
return d.unmarshalInterface(unmarshaler, start)
}
}
}

if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
return d.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler))
if val.CanInterface() {
if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](val); ok {
return d.unmarshalTextInterface(textUnmarshaler)
}
}

if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
return d.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler))
if pv.CanInterface() {
if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok {
return d.unmarshalTextInterface(textUnmarshaler)
}
}
}

Expand Down Expand Up @@ -453,7 +464,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e
return UnmarshalError(e)
}
fv := finfo.value(sv, initNilPointers)
if _, ok := fv.Interface().(Name); ok {
if _, ok := reflect.TypeAssert[Name](fv); ok {
fv.Set(reflect.ValueOf(start.Name))
}
}
Expand Down Expand Up @@ -578,20 +589,24 @@ Loop:
}
}

if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
if saveData.IsValid() && saveData.CanInterface() {
if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](saveData); ok {
if err := textUnmarshaler.UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
saveData = reflect.Value{}
}

if saveData.IsValid() && saveData.CanAddr() {
pv := saveData.Addr()
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
return err
if pv.CanInterface() {
if textUnmarshaler, ok := reflect.TypeAssert[encoding.TextUnmarshaler](pv); ok {
if err := textUnmarshaler.UnmarshalText(data); err != nil {
return err
}
saveData = reflect.Value{}
}
saveData = reflect.Value{}
}
}

Expand Down