// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package proto

import (
	"bytes"
	"encoding"
	"fmt"
	"io"
	"math"
	"sort"
	"strings"

	"google.golang.org/protobuf/encoding/prototext"
	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
)

// wrapTextMarshalV2 selects the implementation used by TextMarshaler.marshal:
// when true, marshaling is delegated to prototext.MarshalOptions; when false,
// the textWriter defined in this file produces the output.
const wrapTextMarshalV2 = false

// TextMarshaler is a configurable text format marshaler.
//
// The zero value marshals in the non-compact format and does not expand
// google.protobuf.Any messages.
type TextMarshaler struct {
	Compact   bool // use compact text format (one line)
	ExpandAny bool // expand google.protobuf.Any messages of known types
}

// Marshal encodes m in the proto text format and writes the result to w.
//
// Any bytes produced by the encoder are written even when encoding reported
// an error. A write failure is returned in preference to an encoding error.
func (tm *TextMarshaler) Marshal(w io.Writer, m Message) error {
	out, marshalErr := tm.marshal(m)
	if len(out) == 0 {
		// Nothing to write; only the encoding outcome matters.
		return marshalErr
	}
	if _, writeErr := w.Write(out); writeErr != nil {
		return writeErr
	}
	return marshalErr
}

// Text encodes m in the proto text format and returns the result as a string.
// An encoding error is discarded; whatever output was produced is returned.
func (tm *TextMarshaler) Text(m Message) string {
	// The error is dropped on purpose: Text has no error result.
	encoded, _ := tm.marshal(m)
	return string(encoded)
}

// marshal encodes m according to tm's options and returns the encoded bytes.
//
// A nil or invalid message encodes as "<nil>". A message that implements
// encoding.TextMarshaler provides its own encoding. Otherwise the message is
// encoded either through prototext (when wrapTextMarshalV2 is set) or through
// the textWriter in this file.
func (tm *TextMarshaler) marshal(m Message) ([]byte, error) {
	mr := MessageReflect(m)
	if mr == nil || !mr.IsValid() {
		return []byte("<nil>"), nil
	}
	custom, hasCustom := m.(encoding.TextMarshaler)

	if wrapTextMarshalV2 {
		if hasCustom {
			return custom.MarshalText()
		}
		opts := prototext.MarshalOptions{AllowPartial: true, EmitUnknown: true}
		if !tm.Compact {
			opts.Indent = "  "
		}
		if !tm.ExpandAny {
			opts.Resolver = (*protoregistry.Types)(nil)
		}
		return opts.Marshal(mr.Interface())
	}

	tw := &textWriter{
		compact:   tm.Compact,
		expandAny: tm.ExpandAny,
		complete:  true,
	}
	if !hasCustom {
		err := tw.writeMessage(mr)
		return tw.buf, err
	}

	raw, err := custom.MarshalText()
	if err != nil {
		return nil, err
	}
	// Route the custom output through the writer so that it receives the
	// same newline handling as generated output.
	tw.Write(raw)
	return tw.buf, nil
}

// Shared marshalers backing the package-level helper functions:
// defaultTextMarshaler is used by MarshalText and MarshalTextString,
// compactTextMarshaler by CompactText and CompactTextString.
var (
	defaultTextMarshaler = TextMarshaler{}
	compactTextMarshaler = TextMarshaler{Compact: true}
)

// MarshalText writes the proto text format of m to w using the default
// (non-compact) marshaler.
func MarshalText(w io.Writer, m Message) error {
	return defaultTextMarshaler.Marshal(w, m)
}

// MarshalTextString returns the proto text format of m as a string, using the
// default (non-compact) marshaler.
func MarshalTextString(m Message) string {
	return defaultTextMarshaler.Text(m)
}

// CompactText writes the compact (single-line) proto text format of m to w.
func CompactText(w io.Writer, m Message) error {
	return compactTextMarshaler.Marshal(w, m)
}

// CompactTextString returns the compact (single-line) proto text format of m
// as a string.
func CompactTextString(m Message) string {
	return compactTextMarshaler.Text(m)
}

// Byte sequences used by textWriter: the line separator that Write splits on,
// the closing line for an unknown-field group, and the spellings of the
// special floating-point values.
var (
	newline         = []byte("\n")
	endBraceNewline = []byte("}\n")
	posInf          = []byte("inf")
	negInf          = []byte("-inf")
	nan             = []byte("nan")
)

// textWriter is an io.Writer that tracks its indentation level.
//
// All output is appended to buf. In non-compact mode, the first bytes written
// at the start of a line are preceded by two spaces per indentation level.
type textWriter struct {
	compact   bool // same as TextMarshaler.Compact
	expandAny bool // same as TextMarshaler.ExpandAny
	complete  bool // whether the current position is a complete line
	indent    int  // indentation level; never negative
	// buf holds everything written so far.
	buf       []byte
}

// Write appends p to the buffer and implements io.Writer. It always reports
// len(p) bytes written and a nil error.
//
// In compact mode every newline in p is replaced by a single space. In
// non-compact mode indentation is emitted when the write begins at the start
// of a line, and the writer records whether p ended with a newline.
func (w *textWriter) Write(p []byte) (n int, _ error) {
	if bytes.IndexByte(p, '\n') < 0 {
		// Single-line payload: indent if we are at the start of a line.
		if w.complete && !w.compact {
			w.writeIndent()
		}
		w.buf = append(w.buf, p...)
		w.complete = false
		return len(p), nil
	}

	lines := bytes.Split(p, newline)
	last := len(lines) - 1

	if w.compact {
		// Join the pieces with spaces instead of newlines.
		for i, line := range lines {
			if i != 0 {
				w.buf = append(w.buf, ' ')
			}
			w.buf = append(w.buf, line...)
		}
		return len(p), nil
	}

	for i, line := range lines {
		// writeIndent clears w.complete, so this fires at most once per call.
		if w.complete {
			w.writeIndent()
		}
		w.buf = append(w.buf, line...)
		if i != last {
			w.buf = append(w.buf, '\n')
		}
	}
	// An empty trailing piece means p ended with a newline.
	w.complete = len(lines[last]) == 0
	return len(p), nil
}

// WriteByte appends the single byte c to the buffer and always returns nil.
//
// In compact mode a newline is written as a space and no indentation is
// emitted. Otherwise indentation is emitted first when the writer is at the
// start of a line. Afterwards the writer is at the start of a line only if a
// newline was actually written.
func (w *textWriter) WriteByte(c byte) error {
	switch {
	case w.compact:
		if c == '\n' {
			c = ' '
		}
	case w.complete:
		w.writeIndent()
	}
	w.buf = append(w.buf, c)
	w.complete = c == '\n'
	return nil
}

// writeName writes the label that precedes a field value.
//
// A non-group field is written as its field name followed by ':'. A group
// field is written as the name of its message type with no colon. In
// non-compact mode the label is indented when it starts a line and is
// followed by a single space.
func (w *textWriter) writeName(fd protoreflect.FieldDescriptor) {
	if w.complete && !w.compact {
		w.writeIndent()
	}
	w.complete = false

	if fd.Kind() == protoreflect.GroupKind {
		// Use message type name for group field name.
		w.buf = append(w.buf, fd.Message().Name()...)
	} else {
		w.buf = append(w.buf, fd.Name()...)
		w.WriteByte(':')
	}

	if w.compact {
		return
	}
	w.WriteByte(' ')
}

// requiresQuotes reports whether the type URL u has to be written as a quoted
// string. A URL may be written bare only when every byte is an ASCII letter,
// an ASCII digit, or one of '.', '/' and '_'; any other byte (including '-'
// and all non-ASCII bytes) forces quoting.
func requiresQuotes(u string) bool {
	for i := 0; i < len(u); i++ {
		c := u[i]
		isDigit := '0' <= c && c <= '9'
		isLetter := ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z')
		isPunct := c == '.' || c == '/' || c == '_'
		if !(isDigit || isLetter || isPunct) {
			return true
		}
	}
	return false
}

// writeProto3Any tries to write m, a google.protobuf.Any message, in expanded
// form: the type URL in square brackets followed by the decoded payload
// enclosed in angle brackets.
//
// It returns (false, nil), having written nothing, when the type URL cannot
// be resolved through protoregistry.GlobalTypes or when the payload bytes
// cannot be unmarshaled into the resolved type.
//
// Otherwise it returns true together with any error encountered while
// writing the payload message.
func (w *textWriter) writeProto3Any(m protoreflect.Message) (bool, error) {
	fields := m.Descriptor().Fields()
	urlField := fields.ByName("type_url")
	valueField := fields.ByName("value")

	typeURL := m.Get(urlField).String()
	msgType, err := protoregistry.GlobalTypes.FindMessageByURL(typeURL)
	if err != nil {
		return false, nil
	}

	raw := m.Get(valueField).Bytes()
	payload := msgType.New()
	if err := proto.Unmarshal(raw, payload.Interface()); err != nil {
		return false, nil
	}

	// Pick the delimiters once; compact output stays on a single line.
	opening, closing := "]: <\n", ">\n"
	if w.compact {
		opening, closing = "]:<", "> "
	}

	w.Write([]byte("["))
	if requiresQuotes(typeURL) {
		w.writeQuotedString(typeURL)
	} else {
		w.Write([]byte(typeURL))
	}
	w.Write([]byte(opening))
	if !w.compact {
		w.indent++
	}

	if err := w.writeMessage(payload); err != nil {
		return true, err
	}

	if !w.compact {
		w.indent--
	}
	w.Write([]byte(closing))
	return true, nil
}

// writeMessage writes the populated fields of m, then its unknown fields,
// then its extensions.
//
// When expandAny is set and m is a google.protobuf.Any, the expanded form is
// attempted first; if that succeeds (or fails while writing) nothing else is
// written. Fields are visited in descriptor order. For a field that belongs
// to a oneof only the member reported by WhichOneof is considered, and the
// index advances by the number of fields in that oneof. Repeated fields are
// written one element per line and map entries are written sorted by key.
func (w *textWriter) writeMessage(m protoreflect.Message) error {
	desc := m.Descriptor()
	if w.expandAny && desc.FullName() == "google.protobuf.Any" {
		if expanded, err := w.writeProto3Any(m); expanded {
			return err
		}
	}

	fields := desc.Fields()
	for idx := 0; idx < fields.Len(); {
		fd := fields.Get(idx)
		if oneof := fd.ContainingOneof(); oneof == nil {
			idx++
		} else {
			fd = m.WhichOneof(oneof)
			idx += oneof.Fields().Len()
		}
		if fd == nil || !m.Has(fd) {
			continue
		}

		if fd.IsList() {
			items := m.Get(fd).List()
			for k := 0; k < items.Len(); k++ {
				w.writeName(fd)
				if err := w.writeSingularValue(items.Get(k), fd); err != nil {
					return err
				}
				w.WriteByte('\n')
			}
		} else if fd.IsMap() {
			keyDesc, valDesc := fd.MapKey(), fd.MapValue()

			// Collect the entries so they can be emitted in key order.
			type pair struct{ k, v protoreflect.Value }
			var pairs []pair
			m.Get(fd).Map().Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
				pairs = append(pairs, pair{k.Value(), v})
				return true
			})

			keyLess := func(a, b protoreflect.Value) bool {
				switch keyDesc.Kind() {
				case protoreflect.BoolKind:
					return !a.Bool() && b.Bool()
				case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
					return a.Int() < b.Int()
				case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
					return a.Uint() < b.Uint()
				case protoreflect.StringKind:
					return a.String() < b.String()
				default:
					panic("invalid kind")
				}
			}
			sort.Slice(pairs, func(a, b int) bool {
				return keyLess(pairs[a].k, pairs[b].k)
			})

			// Each entry is written as a nested block holding a key line
			// and a value line.
			for _, p := range pairs {
				w.writeName(fd)
				w.WriteByte('<')
				if !w.compact {
					w.WriteByte('\n')
				}
				w.indent++

				w.writeName(keyDesc)
				if err := w.writeSingularValue(p.k, keyDesc); err != nil {
					return err
				}
				w.WriteByte('\n')

				w.writeName(valDesc)
				if err := w.writeSingularValue(p.v, valDesc); err != nil {
					return err
				}
				w.WriteByte('\n')

				w.indent--
				w.WriteByte('>')
				w.WriteByte('\n')
			}
		} else {
			w.writeName(fd)
			if err := w.writeSingularValue(m.Get(fd), fd); err != nil {
				return err
			}
			w.WriteByte('\n')
		}
	}

	if unknown := m.GetUnknown(); len(unknown) > 0 {
		w.writeUnknownFields(unknown)
	}
	return w.writeExtensions(m)
}

// writeSingularValue writes the single value v of the field described by fd.
//
// Floating-point infinities and NaN are written as "inf", "-inf" and "nan".
// Strings and bytes are written as quoted strings. Messages are enclosed in
// '<' and '>' (groups in '{' and '}'); a message that implements
// encoding.TextMarshaler supplies its own body. Enums are written by name
// when the number is declared and by number otherwise. Every other kind is
// formatted with fmt.Fprint.
//
// It returns an error if a nested message fails to marshal, at any depth.
func (w *textWriter) writeSingularValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
	switch fd.Kind() {
	case protoreflect.FloatKind, protoreflect.DoubleKind:
		switch vf := v.Float(); {
		case math.IsInf(vf, +1):
			w.Write(posInf)
		case math.IsInf(vf, -1):
			w.Write(negInf)
		case math.IsNaN(vf):
			w.Write(nan)
		default:
			fmt.Fprint(w, v.Interface())
		}
	case protoreflect.StringKind:
		// NOTE: This does not validate UTF-8 for historical reasons.
		w.writeQuotedString(v.String())
	case protoreflect.BytesKind:
		w.writeQuotedString(string(v.Bytes()))
	case protoreflect.MessageKind, protoreflect.GroupKind:
		var bra, ket byte = '<', '>'
		if fd.Kind() == protoreflect.GroupKind {
			bra, ket = '{', '}'
		}
		w.WriteByte(bra)
		if !w.compact {
			w.WriteByte('\n')
		}
		w.indent++
		m := v.Message()
		if m2, ok := m.Interface().(encoding.TextMarshaler); ok {
			b, err := m2.MarshalText()
			if err != nil {
				return err
			}
			w.Write(b)
		} else if err := w.writeMessage(m); err != nil {
			// Propagate failures from deeper levels instead of dropping
			// them, matching the handling of a direct MarshalText error.
			return err
		}
		w.indent--
		w.WriteByte(ket)
	case protoreflect.EnumKind:
		if ev := fd.Enum().Values().ByNumber(v.Enum()); ev != nil {
			fmt.Fprint(w, ev.Name())
		} else {
			fmt.Fprint(w, v.Enum())
		}
	default:
		fmt.Fprint(w, v.Interface())
	}
	return nil
}

// writeQuotedString writes s as a double-quoted string in the protocol buffer
// text format. Newline, carriage return, tab, double quote and backslash use
// their two-character escapes; every other byte outside the printable ASCII
// range 0x20-0x7e is written as a three-digit octal escape.
func (w *textWriter) writeQuotedString(s string) {
	w.WriteByte('"')
	for _, c := range []byte(s) {
		var esc string
		switch c {
		case '\n':
			esc = `\n`
		case '\r':
			esc = `\r`
		case '\t':
			esc = `\t`
		case '"':
			esc = `\"`
		case '\\':
			esc = `\\`
		}

		switch {
		case esc != "":
			w.buf = append(w.buf, esc...)
		case c < 0x20 || c >= 0x7f:
			w.buf = append(w.buf, fmt.Sprintf(`\%03o`, c)...)
		default:
			w.buf = append(w.buf, c)
		}
	}
	w.WriteByte('"')
}

// writeUnknownFields writes the wire-format bytes b, one field per line, as
// "<number>: <value>". Groups are written as "<number> {" ... "}" with their
// contents indented. In non-compact mode a comment giving the byte count is
// written first.
//
// Decoding stops silently at the first tag or value that cannot be parsed.
// Indentation changes made here never leak to the caller: an end-group tag
// without a matching start does not reduce the indentation below its level
// on entry, and that level is restored on return even when a group is left
// open by truncated input.
func (w *textWriter) writeUnknownFields(b []byte) {
	if !w.compact {
		fmt.Fprintf(w, "/* %d unknown bytes */\n", len(b))
	}

	base := w.indent
	defer func() { w.indent = base }()

	for len(b) > 0 {
		num, wtyp, n := protowire.ConsumeTag(b)
		if n < 0 {
			return
		}
		b = b[n:]

		if wtyp == protowire.EndGroupType {
			// Only close groups that were opened in this call.
			if w.indent > base {
				w.indent--
			}
			w.Write(endBraceNewline)
			continue
		}
		fmt.Fprint(w, num)
		if wtyp != protowire.StartGroupType {
			w.WriteByte(':')
		}
		if !w.compact || wtyp == protowire.StartGroupType {
			w.WriteByte(' ')
		}
		switch wtyp {
		case protowire.VarintType:
			v, n := protowire.ConsumeVarint(b)
			if n < 0 {
				return
			}
			b = b[n:]
			fmt.Fprint(w, v)
		case protowire.Fixed32Type:
			v, n := protowire.ConsumeFixed32(b)
			if n < 0 {
				return
			}
			b = b[n:]
			fmt.Fprint(w, v)
		case protowire.Fixed64Type:
			v, n := protowire.ConsumeFixed64(b)
			if n < 0 {
				return
			}
			b = b[n:]
			fmt.Fprint(w, v)
		case protowire.BytesType:
			v, n := protowire.ConsumeBytes(b)
			if n < 0 {
				return
			}
			b = b[n:]
			fmt.Fprintf(w, "%q", v)
		case protowire.StartGroupType:
			w.WriteByte('{')
			w.indent++
		default:
			fmt.Fprintf(w, "/* unknown wire type %d */", wtyp)
		}
		w.WriteByte('\n')
	}
}

// writeExtensions writes every extension field set on m, ordered by field
// number. Nothing is written when m's descriptor declares no extension
// ranges. Each element of a repeated extension is written as its own entry.
func (w *textWriter) writeExtensions(m protoreflect.Message) error {
	if m.Descriptor().ExtensionRanges().Len() == 0 {
		return nil
	}

	type extField struct {
		fd protoreflect.FieldDescriptor
		v  protoreflect.Value
	}
	var found []extField
	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
		if fd.IsExtension() {
			found = append(found, extField{fd: fd, v: v})
		}
		return true
	})
	sort.Slice(found, func(a, b int) bool {
		return found[a].fd.Number() < found[b].fd.Number()
	})

	for _, x := range found {
		// For message set, use the name of the message as the extension name.
		label := string(x.fd.FullName())
		if isMessageSet(x.fd.ContainingMessage()) {
			label = strings.TrimSuffix(label, ".message_set_extension")
		}

		if x.fd.IsList() {
			elems := x.v.List()
			for i := 0; i < elems.Len(); i++ {
				if err := w.writeSingularExtension(label, elems.Get(i), x.fd); err != nil {
					return err
				}
			}
			continue
		}
		if err := w.writeSingularExtension(label, x.v, x.fd); err != nil {
			return err
		}
	}
	return nil
}

// writeSingularExtension writes one extension entry as "[name]:" followed by
// the value v and a newline. In non-compact mode a space separates the label
// from the value. If writing the value fails, the error is returned and no
// trailing newline is written.
func (w *textWriter) writeSingularExtension(name string, v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
	fmt.Fprintf(w, "[%s]:", name)
	if !w.compact {
		w.WriteByte(' ')
	}
	err := w.writeSingularValue(v, fd)
	if err == nil {
		w.WriteByte('\n')
	}
	return err
}

// writeIndent emits the indentation for the current level (two spaces per
// level) if the writer is at the start of a line, and then marks the line as
// started. It does nothing when the line has already been started.
func (w *textWriter) writeIndent() {
	if w.complete {
		for remaining := w.indent * 2; remaining > 0; remaining-- {
			w.buf = append(w.buf, ' ')
		}
		w.complete = false
	}
}