package readline

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"sync/atomic"
)

// MsgType identifies the kind of a Message exchanged between a RemoteCli and
// a RemoteSvr. It is serialized as a big-endian int16 on the wire.
type MsgType int16

// Message kinds. The numeric values are part of the wire format, so new kinds
// must only ever be appended.
const (
	// T_DATA carries raw payload bytes (sent in both directions).
	T_DATA MsgType = iota
	// T_WIDTH is reserved; nothing in this file sends or handles it.
	T_WIDTH
	// T_WIDTH_REPORT carries the client's screen width as a 2-byte value.
	T_WIDTH_REPORT
	// T_ISTTY_REPORT carries whether the client side is a terminal (2 bytes).
	T_ISTTY_REPORT
	// T_RAW asks the client to enter raw mode.
	T_RAW
	// T_ERAW asks the client to exit raw mode.
	T_ERAW
	// T_EOF tells the server that the client's input is exhausted.
	T_EOF
)

// RemoteSvr is the server end of a remote readline session. It is used as
// the Stdin/Stdout/Stderr of a Config (see HandleConfig) and translates
// reads and writes into Messages exchanged over conn.
type RemoteSvr struct {
	// eof is set to 1 (atomically) by readLoop once a T_EOF message arrives.
	eof           int32
	// closed is the compare-and-swap guard that makes Close run only once.
	closed        int32
	// width is the last width reported by the client; it starts at -1 and is
	// accessed atomically.
	width         int32
	// reciveChan is signalled (non-blocking) by readLoop when data or EOF
	// arrives; Read waits on it when the buffer is empty.
	reciveChan    chan struct{}
	// writeChan hands write requests to writeLoop.
	writeChan     chan *writeCtx
	// conn is the underlying connection to the client.
	conn          net.Conn
	// isTerminal holds the value last reported via T_ISTTY_REPORT.
	isTerminal    bool
	// funcWidthChan, when non-nil, is invoked from GotReportWidth after the
	// width has been updated. It is installed through
	// Config.FuncOnWidthChanged in HandleConfig.
	funcWidthChan func()
	// stopChan is closed by Close and tells writeLoop to exit.
	stopChan      chan struct{}

	// dataBufM guards dataBuf.
	dataBufM sync.Mutex
	// dataBuf accumulates T_DATA payloads until they are consumed by Read.
	dataBuf  bytes.Buffer
}

// writeReply is the result writeLoop sends back for a single write request.
type writeReply struct {
	// n is the byte count returned by Message.WriteTo for the request.
	n   int
	// err is the error returned by Message.WriteTo, if any.
	err error
}

// writeCtx is one write request queued on RemoteSvr.writeChan.
type writeCtx struct {
	// msg is the message writeLoop should write to the connection.
	msg   *Message
	// reply receives exactly one writeReply once the write has been attempted.
	reply chan *writeReply
}

// newWriteCtx wraps msg in a write request with a fresh, unbuffered reply
// channel.
func newWriteCtx(msg *Message) *writeCtx {
	ctx := new(writeCtx)
	ctx.msg = msg
	ctx.reply = make(chan *writeReply)
	return ctx
}

// NewRemoteSvr creates the server end of a remote session on conn.
//
// It first performs the initial handshake synchronously (see init) and
// returns the handshake error, if any. On success it starts the read and
// write loops in their own goroutines; both loops call Close when they exit.
func NewRemoteSvr(conn net.Conn) (*RemoteSvr, error) {
	svr := new(RemoteSvr)
	svr.conn = conn
	svr.width = -1 // no width known until the client reports one
	svr.writeChan = make(chan *writeCtx)
	svr.reciveChan = make(chan struct{})
	svr.stopChan = make(chan struct{})

	// The same buffered reader must be shared by the handshake and the read
	// loop so that bytes buffered during the handshake are not lost.
	reader := bufio.NewReader(svr.conn)
	if err := svr.init(reader); err != nil {
		return nil, err
	}

	go svr.readLoop(reader)
	go svr.writeLoop()
	return svr, nil
}

// init performs the opening handshake: the client must send a
// T_ISTTY_REPORT followed by a T_WIDTH_REPORT. Any read error is returned
// as-is; a message of the wrong type yields an "unexpected init message"
// error.
func (r *RemoteSvr) init(buf *bufio.Reader) error {
	handshake := []struct {
		want   MsgType
		handle func([]byte)
	}{
		{T_ISTTY_REPORT, r.GotIsTerminal},
		{T_WIDTH_REPORT, r.GotReportWidth},
	}

	for _, step := range handshake {
		msg, err := ReadMessage(buf)
		if err != nil {
			return err
		}
		if msg.Type != step.want {
			return fmt.Errorf("unexpected init message")
		}
		step.handle(msg.Data)
	}
	return nil
}

// HandleConfig rewires cfg so that a readline instance created from it talks
// to the remote client: all three standard streams point at r, and the
// terminal hooks (raw mode, tty detection, width query and width-change
// notification) are served by r.
func (r *RemoteSvr) HandleConfig(cfg *Config) {
	cfg.Stdin = r
	cfg.Stdout = r
	cfg.Stderr = r

	cfg.FuncIsTerminal = r.IsTerminal
	cfg.FuncMakeRaw = r.EnterRawMode
	cfg.FuncExitRaw = r.ExitRawMode
	cfg.FuncGetWidth = r.GetWidth
	// Remember the callback; GotReportWidth invokes it on every width report.
	cfg.FuncOnWidthChanged = func(f func()) {
		r.funcWidthChan = f
	}
}

// IsTerminal returns the terminal flag most recently stored by
// GotIsTerminal.
func (r *RemoteSvr) IsTerminal() bool {
	return r.isTerminal
}

// checkEOF returns io.EOF once the eof flag has been set (by a T_EOF
// message), and nil otherwise.
func (r *RemoteSvr) checkEOF() error {
	if atomic.LoadInt32(&r.eof) != 1 {
		return nil
	}
	return io.EOF
}

// Read implements io.Reader on top of the data received from the client.
//
// Buffered data is returned first. When the buffer is empty, Read returns
// io.EOF if the client has sent T_EOF, and otherwise blocks until readLoop
// signals new input. If the server is closed while waiting, Read drains
// whatever is still buffered and then reports io.EOF instead of blocking
// forever. A wake-up that finds the buffer still empty (for example an empty
// T_DATA payload) simply waits again rather than reporting a spurious EOF.
func (r *RemoteSvr) Read(b []byte) (int, error) {
	for {
		r.dataBufM.Lock()
		n, err := r.dataBuf.Read(b)
		r.dataBufM.Unlock()
		if n > 0 {
			return n, nil
		}
		if eofErr := r.checkEOF(); eofErr != nil {
			return 0, eofErr
		}
		// bytes.Buffer reports io.EOF only when it is empty and len(b) > 0;
		// anything else (e.g. a zero-length b) is returned unchanged.
		if err != io.EOF {
			return 0, err
		}

		select {
		case <-r.reciveChan:
			// New data or EOF has arrived; re-examine the buffer.
		case <-r.stopChan:
			// No more input will arrive. Hand out anything that was
			// buffered just before the shutdown, then report EOF.
			r.dataBufM.Lock()
			n, _ = r.dataBuf.Read(b)
			r.dataBufM.Unlock()
			if n > 0 {
				return n, nil
			}
			return 0, io.EOF
		}
	}
}

// writeMsg queues m for writeLoop and waits for the outcome of the write.
//
// If the server has been closed (stopChan is closed) the request can no
// longer be picked up, so io.ErrClosedPipe is returned instead of blocking
// forever on the unbuffered writeChan.
func (r *RemoteSvr) writeMsg(m *Message) error {
	ctx := newWriteCtx(m)
	select {
	case r.writeChan <- ctx:
	case <-r.stopChan:
		return io.ErrClosedPipe
	}
	// Once writeLoop has accepted the request it always sends a reply, so
	// this receive must not be abandoned (writeLoop would block on it).
	reply := <-ctx.reply
	return reply.err
}

// Write implements io.Writer by sending b to the client as one T_DATA
// message.
//
// The returned count refers to bytes of b, never exceeding len(b): the frame
// header written by Message.WriteTo is subtracted from the raw byte count so
// that the io.Writer contract (0 <= n <= len(b)) holds. If the server has
// been closed, Write returns io.ErrClosedPipe instead of blocking forever.
func (r *RemoteSvr) Write(b []byte) (int, error) {
	// Frame header emitted by Message.WriteTo: 4-byte length + 2-byte type.
	const headerLen = 4 + 2

	ctx := newWriteCtx(NewMessage(T_DATA, b))
	select {
	case r.writeChan <- ctx:
	case <-r.stopChan:
		return 0, io.ErrClosedPipe
	}
	// writeLoop always replies to an accepted request.
	reply := <-ctx.reply

	n := reply.n - headerLen
	if n < 0 {
		n = 0
	}
	if n > len(b) {
		n = len(b)
	}
	return n, reply.err
}

// EnterRawMode sends a T_RAW message (with no payload) to the client and
// returns the result of that write.
func (r *RemoteSvr) EnterRawMode() error {
	return r.writeMsg(NewMessage(T_RAW, nil))
}

// ExitRawMode sends a T_ERAW message (with no payload) to the client and
// returns the result of that write.
func (r *RemoteSvr) ExitRawMode() error {
	return r.writeMsg(NewMessage(T_ERAW, nil))
}

// writeLoop serializes all writes to the connection. It takes requests from
// writeChan, writes each message to conn and sends the outcome back on the
// request's reply channel. The loop ends when stopChan is closed or when
// writeChan is closed; on exit it closes the server.
func (r *RemoteSvr) writeLoop() {
	defer r.Close()

	for {
		select {
		case ctx, ok := <-r.writeChan:
			if !ok {
				// A bare break would only leave the select and spin on the
				// closed channel, so leave the function explicitly.
				return
			}
			n, err := ctx.msg.WriteTo(r.conn)
			ctx.reply <- &writeReply{n: n, err: err}
		case <-r.stopChan:
			return
		}
	}
}

// Close shuts the server down: it closes stopChan and the underlying
// connection. Only the first call has any effect; the error from closing the
// connection is discarded and nil is always returned.
func (r *RemoteSvr) Close() error {
	firstCall := atomic.CompareAndSwapInt32(&r.closed, 0, 1)
	if !firstCall {
		return nil
	}
	close(r.stopChan)
	r.conn.Close()
	return nil
}

// readLoop consumes messages from the client until reading fails, then
// closes the server.
//
//   - T_DATA payloads are appended to dataBuf and a waiting Read is woken.
//   - T_EOF sets the eof flag and wakes a waiting Read.
//   - T_WIDTH_REPORT / T_ISTTY_REPORT update the stored terminal properties.
//
// Other message types are ignored.
func (r *RemoteSvr) readLoop(buf *bufio.Reader) {
	defer r.Close()

	// wake notifies a Read that is currently blocked on reciveChan. The send
	// is non-blocking: if no reader is waiting the signal is dropped.
	wake := func() {
		select {
		case r.reciveChan <- struct{}{}:
		default:
		}
	}

	for {
		msg, err := ReadMessage(buf)
		if err != nil {
			return
		}

		payload := msg.Data
		switch msg.Type {
		case T_DATA:
			r.dataBufM.Lock()
			r.dataBuf.Write(payload)
			r.dataBufM.Unlock()
			wake()
		case T_EOF:
			atomic.StoreInt32(&r.eof, 1)
			wake()
		case T_ISTTY_REPORT:
			r.GotIsTerminal(payload)
		case T_WIDTH_REPORT:
			r.GotReportWidth(payload)
		}
	}
}

// GotIsTerminal records the client's terminal flag from a T_ISTTY_REPORT
// payload: a big-endian uint16 where zero means "not a terminal".
//
// The payload comes straight off the network, so one shorter than two bytes
// is ignored rather than allowed to panic in the decoder.
func (r *RemoteSvr) GotIsTerminal(data []byte) {
	if len(data) < 2 {
		return
	}
	r.isTerminal = binary.BigEndian.Uint16(data) != 0
}

// GotReportWidth stores the width carried by a T_WIDTH_REPORT payload (a
// big-endian uint16) and then invokes the width-changed callback, if one has
// been registered.
//
// The payload comes straight off the network, so one shorter than two bytes
// is ignored rather than allowed to panic in the decoder.
func (r *RemoteSvr) GotReportWidth(data []byte) {
	if len(data) < 2 {
		return
	}
	atomic.StoreInt32(&r.width, int32(binary.BigEndian.Uint16(data)))
	if r.funcWidthChan != nil {
		r.funcWidthChan()
	}
}

// GetWidth returns the most recently stored width (read atomically). The
// value is -1 until the first width report has been processed.
func (r *RemoteSvr) GetWidth() int {
	return int(atomic.LoadInt32(&r.width))
}

// -----------------------------------------------------------------------------

// Message is one protocol frame. On the wire it is a big-endian int32
// length (covering the type and the data), a big-endian int16 type, and the
// data bytes; see ReadMessage and WriteTo.
type Message struct {
	// Type is the kind of message.
	Type MsgType
	// Data is the payload; it may be empty or nil.
	Data []byte
}

// ReadMessage reads one frame from r: a big-endian int32 length, a
// big-endian int16 message type, and length-2 bytes of payload.
//
// It returns any error from the underlying reads. A length below 2 (the size
// of the type field) cannot describe a valid frame and is rejected with an
// error; previously such a value produced a negative allocation size and a
// panic.
func ReadMessage(r io.Reader) (*Message, error) {
	var length int32
	if err := binary.Read(r, binary.BigEndian, &length); err != nil {
		return nil, err
	}
	if length < 2 {
		return nil, fmt.Errorf("invalid message length %d", length)
	}

	m := new(Message)
	if err := binary.Read(r, binary.BigEndian, &m.Type); err != nil {
		return nil, err
	}
	// The length field counts the 2-byte type as well as the payload.
	m.Data = make([]byte, int(length)-2)
	if _, err := io.ReadFull(r, m.Data); err != nil {
		return nil, err
	}
	return m, nil
}

// NewMessage returns a Message of type t whose payload is data. The slice is
// stored as given, not copied.
func NewMessage(t MsgType, data []byte) *Message {
	return &Message{
		Type: t,
		Data: data,
	}
}

// WriteTo serializes m into a single buffer (4-byte big-endian length,
// 2-byte big-endian type, payload) and writes that buffer to w. It returns
// the number of bytes written, header included, and any write error.
func (m *Message) WriteTo(w io.Writer) (int, error) {
	payloadLen := len(m.Data)

	// The length field covers the type field plus the payload.
	var header [4 + 2]byte
	binary.BigEndian.PutUint32(header[:4], uint32(int32(payloadLen+2)))
	binary.BigEndian.PutUint16(header[4:], uint16(m.Type))

	frame := bytes.NewBuffer(make([]byte, 0, payloadLen+2+4))
	frame.Write(header[:])
	frame.Write(m.Data)

	written, err := frame.WriteTo(w)
	return int(written), err
}

// -----------------------------------------------------------------------------

// RemoteCli is the client end of a remote readline session. It forwards
// local input to the server as T_DATA messages and applies the messages it
// receives (output data, raw-mode switches) to the local terminal.
type RemoteCli struct {
	// conn is the connection to the server.
	conn        net.Conn
	// raw is entered/exited by readLoop on T_RAW/T_ERAW messages.
	raw         RawMode
	// receiveChan is allocated by NewRemoteCli; no other use is visible in
	// this file.
	receiveChan chan struct{}
	// inited is the compare-and-swap guard that makes init run only once.
	inited      int32
	// isTerminal, when non-nil, overrides DefaultIsTerminal in the initial
	// report; it is set by MarkIsTerminal.
	isTerminal  *bool

	// data has no use visible in this file.
	data  bytes.Buffer
	// dataM serializes writes to conn (see writeMsg and Write).
	dataM sync.Mutex
}

// NewRemoteCli creates a client for the given connection. Nothing is sent
// until ServeBy (or Serve) is called. The returned error is always nil.
func NewRemoteCli(conn net.Conn) (*RemoteCli, error) {
	cli := new(RemoteCli)
	cli.conn = conn
	cli.receiveChan = make(chan struct{})
	return cli, nil
}

// MarkIsTerminal overrides the terminal flag that reportIsTerminal sends to
// the server; without it DefaultIsTerminal is consulted.
func (r *RemoteCli) MarkIsTerminal(is bool) {
	r.isTerminal = &is
}

// init sends the opening handshake (terminal flag, then screen width) and
// registers a callback that re-reports the width whenever it changes.
// Only the first call does any work; later calls return nil immediately.
func (r *RemoteCli) init() error {
	alreadyInited := !atomic.CompareAndSwapInt32(&r.inited, 0, 1)
	if alreadyInited {
		return nil
	}

	// The server expects exactly this order during its own init.
	for _, report := range []func() error{r.reportIsTerminal, r.reportWidth} {
		if err := report(); err != nil {
			return err
		}
	}

	// Keep the server informed about later width changes; a failed report
	// here is deliberately ignored.
	DefaultOnWidthChanged(func() {
		r.reportWidth()
	})
	return nil
}

// writeMsg writes m to the connection while holding dataM, so that frames
// from concurrent writers are never interleaved. It returns the write error.
func (r *RemoteCli) writeMsg(m *Message) error {
	r.dataM.Lock()
	defer r.dataM.Unlock()

	_, err := m.WriteTo(r.conn)
	return err
}

// Write implements io.Writer by sending b to the server as one T_DATA
// message.
//
// On success it returns len(b). On failure it returns 0 together with the
// error, instead of claiming that all of b was written; this also lets
// io.Copy-based callers notice that nothing went through.
func (r *RemoteCli) Write(b []byte) (int, error) {
	if err := r.writeMsg(NewMessage(T_DATA, b)); err != nil {
		return 0, err
	}
	return len(b), nil
}

// reportWidth sends the current screen width (as returned by
// GetScreenWidth, truncated to uint16) to the server in a T_WIDTH_REPORT
// message and returns the write error, if any.
func (r *RemoteCli) reportWidth() error {
	payload := make([]byte, 2)
	binary.BigEndian.PutUint16(payload, uint16(GetScreenWidth()))
	return r.writeMsg(NewMessage(T_WIDTH_REPORT, payload))
}

// reportIsTerminal tells the server whether this side is a terminal, using a
// T_ISTTY_REPORT message whose payload is a big-endian uint16 (1 = terminal,
// 0 = not). The value set through MarkIsTerminal takes precedence; otherwise
// DefaultIsTerminal decides. The write error, if any, is returned.
func (r *RemoteCli) reportIsTerminal() error {
	var tty bool
	if override := r.isTerminal; override != nil {
		tty = *override
	} else {
		tty = DefaultIsTerminal()
	}

	var flag uint16
	if tty {
		flag = 1
	}
	payload := make([]byte, 2)
	binary.BigEndian.PutUint16(payload, flag)

	return r.writeMsg(NewMessage(T_ISTTY_REPORT, payload))
}

// readLoop applies messages from the server until reading fails: T_DATA is
// copied to os.Stdout, T_RAW and T_ERAW switch raw mode on and off. Other
// message types are ignored.
func (r *RemoteCli) readLoop() {
	reader := bufio.NewReader(r.conn)
	for {
		msg, err := ReadMessage(reader)
		if err != nil {
			return
		}

		switch msg.Type {
		case T_DATA:
			os.Stdout.Write(msg.Data)
		case T_RAW:
			r.raw.Enter()
		case T_ERAW:
			r.raw.Exit()
		}
	}
}

// ServeBy runs the client session using source as its input.
//
// After the handshake it starts a goroutine that copies source to the server
// until a copy pass transfers no bytes, and then sends T_EOF via Close. The
// calling goroutine processes server messages until the connection fails;
// raw mode is exited before returning. Only a handshake failure is reported
// as an error.
func (r *RemoteCli) ServeBy(source io.Reader) error {
	if err := r.init(); err != nil {
		return err
	}

	forward := func() {
		defer r.Close()
		for {
			copied, _ := io.Copy(r, source)
			if copied == 0 {
				return
			}
		}
	}
	go forward()

	defer r.raw.Exit()
	r.readLoop()
	return nil
}

// Close sends a T_EOF message to the server; the write error is ignored.
// It does not close the underlying connection.
func (r *RemoteCli) Close() {
	r.writeMsg(NewMessage(T_EOF, nil))
}

// Serve runs the client session with os.Stdin as the input source; see
// ServeBy.
func (r *RemoteCli) Serve() error {
	return r.ServeBy(os.Stdin)
}

// ListenRemote listens on the given network and address and serves every
// accepted connection as a remote readline session.
//
// n and addr are passed to net.Listen. If an onListen callback is supplied,
// the first one is called with the listener before accepting; a non-nil
// error from it aborts ListenRemote. For each connection a copy of *cfg is
// bound to the connection via HandleConn and the resulting instance is
// passed to h in its own goroutine; the connection is closed when h returns
// or when the session could not be set up.
//
// The listener is closed when ListenRemote returns, so it is not leaked when
// onListen fails or Accept stops working. An Accept error ends the loop and
// ListenRemote returns nil, as before.
func ListenRemote(n, addr string, cfg *Config, h func(*Instance), onListen ...func(net.Listener) error) error {
	ln, err := net.Listen(n, addr)
	if err != nil {
		return err
	}
	// Closing an already-closed listener only yields an error, which is
	// ignored here, so this is safe even if the caller closed it first.
	defer ln.Close()

	if len(onListen) > 0 {
		if err := onListen[0](ln); err != nil {
			return err
		}
	}

	for {
		conn, err := ln.Accept()
		if err != nil {
			return nil
		}
		go func(conn net.Conn) {
			defer conn.Close()
			rl, err := HandleConn(*cfg, conn)
			if err != nil {
				return
			}
			h(rl)
		}(conn)
	}
}

// HandleConn sets up a remote session on conn and returns a readline
// instance bound to it. cfg is received by value, so the caller's Config is
// not modified when the streams and terminal hooks are redirected to the
// connection. Errors from the handshake (NewRemoteSvr) or from NewEx are
// returned unchanged.
func HandleConn(cfg Config, conn net.Conn) (*Instance, error) {
	svr, err := NewRemoteSvr(conn)
	if err != nil {
		return nil, err
	}
	svr.HandleConfig(&cfg)

	inst, err := NewEx(&cfg)
	if err != nil {
		return nil, err
	}
	return inst, nil
}

// DialRemote connects to a remote readline server at the given network and
// address and serves the session from os.Stdin until it ends. The
// connection is closed before returning.
func DialRemote(n, addr string) error {
	conn, dialErr := net.Dial(n, addr)
	if dialErr != nil {
		return dialErr
	}
	defer conn.Close()

	client, cliErr := NewRemoteCli(conn)
	if cliErr != nil {
		return cliErr
	}
	return client.Serve()
}