package sereal

import (
	"encoding"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"reflect"
	"runtime"
	"strconv"
	"unsafe"
)

// An Encoder encodes Go data structures into Sereal byte streams.
//
// Use NewEncoder, NewEncoderV2 or NewEncoderV3 to get an Encoder for a
// specific protocol version. A zero Encoder is also usable: MarshalWithHeader
// switches it to ProtocolVersion on first use.
type Encoder struct {
	PerlCompat           bool       // try to mimic Perl's structure as much as possible (e.g. emit REFN before containers)
	Compression          compressor // optionally compress the main payload of the document using SnappyCompressor (Incremental required for v2+) or ZlibCompressor (v3+ only)
	CompressionThreshold int        // compress only bodies of at least this many bytes; 0 means always compress (the NewEncoder* constructors set 1024)
	DisableDedup         bool       // disable deduplication (COPY tags) of class names and hash keys
	DisableFREEZE        bool       // disable the OBJECT_FREEZE tag, used for non-pointer encoding.BinaryMarshaler values via MarshalBinary
	ExpectedSize         uint       // hint for the initial capacity of the encoded body buffer
	MaxRecursionDepth    int        // maximum nesting depth; <= 0 means no practical limit (2147483647)
	version              int        // protocol version to encode; 0 is replaced by ProtocolVersion on first use
	tcache               tagsCache  // per-struct field name -> field index lookup used when encoding structs
}

// marshalState is the per-buffer bookkeeping needed to emit back-references.
// Offsets are positions in the output buffer being built.
type marshalState struct {
	strTable map[string]int  // hash key / class name -> offset of its first encoding, target of COPY tags
	ptrTable map[uintptr]int // pointer identity -> offset of the item it was first encoded as, target of REFP tags
}

// compressor is the interface of the payload compressors accepted in
// Encoder.Compression. MarshalWithHeader only recognises SnappyCompressor and
// ZlibCompressor and rejects any other implementation. compress receives the
// encoded document body and returns the bytes to store in its place.
type compressor interface {
	compress(b []byte) ([]byte, error)
}

// NewEncoder returns an Encoder producing Sereal protocol version 1
// documents, with PerlCompat disabled and a CompressionThreshold of 1024 bytes.
func NewEncoder() *Encoder {
	enc := &Encoder{version: 1}
	enc.CompressionThreshold = 1024
	return enc
}

// NewEncoderV2 returns an Encoder producing Sereal protocol version 2
// documents, with PerlCompat disabled and a CompressionThreshold of 1024 bytes.
func NewEncoderV2() *Encoder {
	enc := &Encoder{version: 2}
	enc.CompressionThreshold = 1024
	return enc
}

// NewEncoderV3 returns an Encoder producing Sereal protocol version 3
// documents, with PerlCompat disabled and a CompressionThreshold of 1024 bytes.
func NewEncoderV3() *Encoder {
	enc := &Encoder{version: 3}
	enc.CompressionThreshold = 1024
	return enc
}

// defaultEncoder backs the package-level Marshal; it produces Sereal
// protocol version 3 documents.
var defaultEncoder = NewEncoderV3()

// Marshal encodes body with the default encoder (protocol version 3, no user
// header).
func Marshal(body interface{}) ([]byte, error) {
	return defaultEncoder.Marshal(body)
}

// Marshal returns the Sereal encoding of body without a user header. It is
// shorthand for e.MarshalWithHeader(nil, body).
func (e *Encoder) Marshal(body interface{}) (b []byte, err error) {
	b, err = e.MarshalWithHeader(nil, body)
	return b, err
}

// MarshalWithHeader returns the Sereal encoding of body, preceded by header
// encoded as the document's user header (header-suffix).
//
// The user header is only written for protocol version 2 and later; for
// version 1 a non-nil header is ignored. When e.Compression is set and the
// encoded body is at least e.CompressionThreshold bytes long (or the
// threshold is 0), the body is compressed and the document type stored in
// the header is updated accordingly.
//
// Panics raised while encoding are returned as errors, except runtime errors,
// which are re-raised because they indicate a bug rather than bad input.
func (e *Encoder) MarshalWithHeader(header interface{}, body interface{}) (b []byte, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if _, ok := r.(runtime.Error); ok {
			panic(r)
		}

		b = nil
		switch v := r.(type) {
		case error:
			err = v
		case string:
			err = errors.New(v)
		default:
			// A bare r.(error) assertion would panic again here for values
			// such as panic(42), crashing the caller instead of reporting.
			err = fmt.Errorf("sereal: %v", v)
		}
	}()

	// uninitialized encoder? set to the most recent supported protocol version
	if e.version == 0 {
		e.version = ProtocolVersion
	}

	// Only versions 1-3 have a body encoding below; any other value would
	// silently produce a document without a body.
	if e.version < 1 || e.version > 3 {
		return nil, fmt.Errorf("sereal: unsupported protocol version %d", e.version)
	}

	encHeader := make([]byte, headerSize, 32)

	// protocol version 3 uses a different magic number
	if e.version < 3 {
		binary.LittleEndian.PutUint32(encHeader[:4], magicHeaderBytes)
	} else {
		binary.LittleEndian.PutUint32(encHeader[:4], magicHeaderBytesHighBit)
	}

	// Set the <version-type> component in the header
	encHeader[4] = byte(e.version) | byte(serealRaw)<<4

	if header != nil && e.version >= 2 {
		// The header suffix is encoded into its own buffer, so it gets fresh
		// dedup/pointer tables whose offsets are relative to that buffer.
		state := marshalState{
			strTable: make(map[string]int),
			ptrTable: make(map[uintptr]int),
		}

		// this is both the flag byte (== "there is user data") and also a hack to make 1-based offsets work
		henv := []byte{0x01} // flag byte == "there is user data"
		encHeaderSuffix, herr := e.encode(henv, header, false, false, &state, remainingRecursion(e.MaxRecursionDepth))
		if herr != nil {
			return nil, herr
		}

		encHeader = varint(encHeader, uint(len(encHeaderSuffix)))
		encHeader = append(encHeader, encHeaderSuffix...)
	} else {
		// empty header suffix: its size is 0
		encHeader = append(encHeader, 0)
	}

	remainingDepth := remainingRecursion(e.MaxRecursionDepth)
	state := marshalState{
		strTable: make(map[string]int),
		ptrTable: make(map[uintptr]int),
	}

	encBody := make([]byte, 0, e.ExpectedSize)

	switch e.version {
	case 1:
		// prepend the header so recorded offsets account for it, then strip it
		encBody = append(encBody, encHeader...)
		encBody, err = e.encode(encBody, body, false, false, &state, remainingDepth)
		if len(encBody) >= len(encHeader) {
			encBody = encBody[len(encHeader):]
		}
	case 2, 3:
		// a placeholder byte makes recorded offsets 1-based; strip it afterwards
		encBody = append(encBody, 0)
		encBody, err = e.encode(encBody, body, false, false, &state, remainingDepth)
		if len(encBody) >= 1 {
			encBody = encBody[1:]
		}
	}

	if err != nil {
		return nil, err
	}

	if e.Compression != nil && (e.CompressionThreshold == 0 || len(encBody) >= e.CompressionThreshold) {
		// Validate the compressor/version combination before doing the
		// (potentially expensive) compression work.
		var doctype documentType

		switch c := e.Compression.(type) {
		case SnappyCompressor:
			if e.version > 1 && !c.Incremental {
				return nil, errors.New("non-incremental snappy compression only valid for v1 documents")
			}
			if e.version == 1 {
				doctype = serealSnappy
			} else {
				doctype = serealSnappyIncremental
			}
		case ZlibCompressor:
			if e.version < 3 {
				return nil, errors.New("zlib compression only valid for v3 documents and up")
			}
			doctype = serealZlib
		default:
			// Defensive programming: this point should never be
			// reached in production code because the compressor
			// interface is not exported, hence no way to pass in
			// an unknown thing. But it may happen during
			// development when a new compressor is implemented,
			// but a relevant document type is not defined.
			panic("undefined compression")
		}

		encBody, err = e.Compression.compress(encBody)
		if err != nil {
			return nil, err
		}

		encHeader[4] |= byte(doctype) << 4
	}

	return append(encHeader, encBody...), nil
}

/*************************************
 * Encode via static types - fast path
 *************************************/

// encode appends the Sereal encoding of v to b and returns the extended slice.
//
// Common concrete types are handled directly by the type switch below; any
// other value is passed to encodeViaReflection.
//
//   - isKeyOrClass marks v as a hash key or class name. String-like values
//     are then eligible for deduplication through state.strTable.
//   - isRefNext tells container encoders that the caller has already emitted
//     a REFN tag, so they must not emit another one in PerlCompat mode.
//   - remainingDepth is the recursion budget handed to nested containers.
func (e *Encoder) encode(b []byte, v interface{}, isKeyOrClass bool, isRefNext bool, state *marshalState, remainingDepth int) ([]byte, error) {
	var err error

	switch value := v.(type) {
	case nil:
		b = append(b, typeUNDEF)

	case bool:
		if value {
			b = append(b, typeTRUE)
		} else {
			b = append(b, typeFALSE)
		}

	case int:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int8:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int16:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int32:
		b = e.encodeInt(b, reflect.Int, int64(value))
	case int64:
		b = e.encodeInt(b, reflect.Int, value)

	// unsigned values are passed as their bit pattern reinterpreted as int64
	case uint:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint8:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint16:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint32:
		b = e.encodeInt(b, reflect.Uint, int64(value))
	case uint64:
		b = e.encodeInt(b, reflect.Uint, int64(value))

	case float32:
		b = e.encodeFloat(b, value)
	case float64:
		b = e.encodeDouble(b, value)

	case json.Number:
		b = e.encodeJsonNumber(b, value, isKeyOrClass, state.strTable)

	case string:
		b = e.encodeString(b, value, isKeyOrClass, state.strTable)

	case []uint8:
		b = e.encodeBytes(b, value, isKeyOrClass, state.strTable)

	case []interface{}:
		b, err = e.encodeIntfArray(b, value, isRefNext, state, remainingDepth)

	case map[string]interface{}:
		b, err = e.encodeStrMap(b, value, isRefNext, state, remainingDepth)

	case reflect.Value:
		if value.Kind() == reflect.Invalid {
			b = append(b, typeUNDEF)
		} else {
			// Keep isKeyOrClass: map keys encoded via reflection arrive here
			// wrapped in a reflect.Value and must stay eligible for dedup.
			// (could be optimized to tail call)
			b, err = e.encode(b, value.Interface(), isKeyOrClass, isRefNext, state, remainingDepth)
		}

	case PerlUndef:
		if value.canonical {
			b = append(b, typeCANONICAL_UNDEF)
		} else {
			b = append(b, typeUNDEF)
		}

	case PerlObject:
		b = append(b, typeOBJECT)
		b = e.encodeBytes(b, []byte(value.Class), true, state.strTable)
		b, err = e.encode(b, value.Reference, false, false, state, remainingDepth)

	case PerlRegexp:
		b = append(b, typeREGEXP)
		b = e.encodeBytes(b, value.Pattern, false, state.strTable)
		b = e.encodeBytes(b, value.Modifiers, false, state.strTable)

	case PerlWeakRef:
		b = append(b, typeWEAKEN)
		b, err = e.encode(b, value.Reference, false, false, state, remainingDepth)

	// NOTE: do not add a `case interface{}` here. Every value satisfies
	// interface{}, so it would swallow the default branch. Pointers such as
	// *interface{} reach encodeViaReflection and from there encodePointer.

	default:
		b, err = e.encodeViaReflection(b, reflect.ValueOf(value), isKeyOrClass, isRefNext, state, remainingDepth)
	}

	return b, err
}

func (e *Encoder) encodeInt(by []byte, k reflect.Kind, i int64) []byte {
	switch {
	case 0 <= i && i <= 15:
		by = append(by, byte(i)&0x0f)
	case -16 <= i && i < 0 && k == reflect.Int:
		by = append(by, 0x010|(byte(i)&0x0f))
	case i > 15:
		by = append(by, typeVARINT)
		by = varint(by, uint(i))
	case i < 0:
		n := uint(i)
		if k == reflect.Int {
			by = append(by, typeZIGZAG)
			n = uint((i << 1) ^ (i >> 63))
		} else {
			by = append(by, typeVARINT)
		}

		by = varint(by, n)
	}

	return by
}

// encodeFloat appends f as a Sereal FLOAT: the tag followed by the IEEE-754
// single-precision bits in little-endian byte order.
func (e *Encoder) encodeFloat(by []byte, f float32) []byte {
	by = append(by, typeFLOAT)
	bits := math.Float32bits(f)
	for shift := uint(0); shift < 32; shift += 8 {
		by = append(by, byte(bits>>shift))
	}
	return by
}

// encodeDouble appends f as a Sereal DOUBLE: the tag followed by the IEEE-754
// double-precision bits in little-endian byte order.
func (e *Encoder) encodeDouble(by []byte, f float64) []byte {
	by = append(by, typeDOUBLE)
	bits := math.Float64bits(f)
	for shift := uint(0); shift < 64; shift += 8 {
		by = append(by, byte(bits>>shift))
	}
	return by
}

// encodeJsonNumber appends a json.Number using the most faithful Sereal
// representation:
//   - values that parse as int64 become integers;
//   - integer literals outside the int64 range are kept as strings;
//   - other values that parse as float64 become doubles;
//   - anything else is written as a string.
//
// isKeyOrClass and strTable are forwarded to encodeString.
func (e *Encoder) encodeJsonNumber(by []byte, n json.Number, isKeyOrClass bool, strTable map[string]int) []byte {
	asInt, intErr := n.Int64()
	if intErr == nil {
		return e.encodeInt(by, reflect.Int, asInt)
	}

	// Large integers are often IDs or hashsums, so rather than converting
	// them to float64 (and losing precision) keep them as strings. uint64
	// values above MaxInt64 are a common example, but any size may appear.
	if numErr, ok := intErr.(*strconv.NumError); ok && numErr.Err == strconv.ErrRange {
		return e.encodeString(by, n.String(), isKeyOrClass, strTable)
	}

	asFloat, floatErr := n.Float64()
	if floatErr != nil {
		return e.encodeString(by, n.String(), isKeyOrClass, strTable)
	}
	return e.encodeDouble(by, asFloat)
}

// encodeString appends s as a Sereal STR_UTF8 item (tag, varint length,
// bytes).
//
// When s is a hash key or class name (isKeyOrClass) and deduplication is
// enabled, a string already present in strTable is written as a COPY tag
// pointing at its first occurrence. Otherwise the offset of this occurrence
// is recorded for later copies.
func (e *Encoder) encodeString(by []byte, s string, isKeyOrClass bool, strTable map[string]int) []byte {
	if isKeyOrClass && !e.DisableDedup {
		offset, seen := strTable[s]
		if seen {
			return varint(append(by, typeCOPY), uint(offset))
		}
		strTable[s] = len(by)
	}

	by = varint(append(by, typeSTR_UTF8), uint(len(s)))
	return append(by, s...)
}

// encodeBytes appends byt as a Sereal binary string: SHORT_BINARY_<len> for
// fewer than 32 bytes, otherwise BINARY followed by a varint length.
//
// When byt is a hash key or class name (isKeyOrClass) and deduplication is
// enabled, content already present in strTable is written as a COPY tag
// pointing at its first occurrence. Otherwise the offset of this occurrence
// is recorded.
func (e *Encoder) encodeBytes(by []byte, byt []byte, isKeyOrClass bool, strTable map[string]int) []byte {
	if isKeyOrClass && !e.DisableDedup {
		if offset, seen := strTable[string(byt)]; seen {
			return varint(append(by, typeCOPY), uint(offset))
		}
		// remember where this first occurrence starts
		strTable[string(byt)] = len(by)
	}

	size := len(byt)
	switch {
	case size < 32:
		by = append(by, typeSHORT_BINARY_0+byte(size))
	default:
		by = varint(append(by, typeBINARY), uint(size))
	}

	return append(by, byt...)
}

// encodeIntfArray appends arr as a Sereal ARRAY, encoding the elements in
// order. In PerlCompat mode the array is preceded by a REFN tag unless the
// caller already wrote one (isRefNext). It returns ErrMaxRecursionLimit once
// the recursion budget is exhausted, and a nil slice on any error.
func (e *Encoder) encodeIntfArray(by []byte, arr []interface{}, isRefNext bool, state *marshalState, remainingDepth int) ([]byte, error) {
	depth := remainingDepth - 1
	if depth <= 0 {
		return nil, ErrMaxRecursionLimit
	}

	if !isRefNext && e.PerlCompat {
		by = append(by, typeREFN)
	}

	// TODO implement ARRAYREF for small arrays
	by = append(by, typeARRAY)
	by = varint(by, uint(len(arr)))

	for _, item := range arr {
		var err error
		by, err = e.encode(by, item, false, false, state, depth)
		if err != nil {
			return nil, err
		}
	}

	return by, nil
}

// encodeStrMap appends m as a Sereal HASH. Keys are written as hash keys
// (eligible for deduplication), values as ordinary items; pairs follow Go's
// map iteration order. In PerlCompat mode the hash is preceded by a REFN tag
// unless the caller already wrote one (isRefNext). It returns
// ErrMaxRecursionLimit once the recursion budget is exhausted.
func (e *Encoder) encodeStrMap(by []byte, m map[string]interface{}, isRefNext bool, state *marshalState, remainingDepth int) ([]byte, error) {
	depth := remainingDepth - 1
	if depth <= 0 {
		return nil, ErrMaxRecursionLimit
	}

	if !isRefNext && e.PerlCompat {
		by = append(by, typeREFN)
	}

	// TODO implement HASHREF for small maps
	by = varint(append(by, typeHASH), uint(len(m)))

	for key, val := range m {
		by = e.encodeString(by, key, true, state.strTable)

		var err error
		if by, err = e.encode(by, val, false, false, state, depth); err != nil {
			return by, err
		}
	}

	return by, nil
}

/*************************************
 * Encode via reflection
 *************************************/

// encodeViaReflection appends rv for values that the static type switch in
// encode does not handle.
//
// Unless DisableFREEZE is set, a non-pointer value implementing
// encoding.BinaryMarshaler is written as OBJECT_FREEZE: its package-qualified
// type name, then REFN ARRAY(1) holding the MarshalBinary output. Everything
// else is dispatched on rv.Kind(). Unsupported kinds panic with a descriptive
// string.
func (e *Encoder) encodeViaReflection(b []byte, rv reflect.Value, isKeyOrClass bool, isRefNext bool, state *marshalState, remainingDepth int) ([]byte, error) {
	if !e.DisableFREEZE && rv.Kind() != reflect.Invalid && rv.Kind() != reflect.Ptr {
		if m, ok := rv.Interface().(encoding.BinaryMarshaler); ok {
			by, err := m.MarshalBinary()
			if err != nil {
				return nil, err
			}

			// OBJECT_FREEZE <class name> REFN ARRAY(1) <frozen bytes>
			b = append(b, typeOBJECT_FREEZE)
			b = e.encodeString(b, concreteName(rv), true, state.strTable)
			b = append(b, typeREFN)
			b = append(b, typeARRAY)
			b = varint(b, uint(1))
			return e.encode(b, reflect.ValueOf(by), false, false, state, remainingDepth)
		}
	}

	// make sure we're looking at a real type and not an interface
	for rv.Kind() == reflect.Interface {
		rv = rv.Elem()
	}

	var err error
	switch rk := rv.Kind(); rk {
	case reflect.Invalid:
		// e.g. a nil interface: encode as undef, like encode() does for an
		// invalid reflect.Value, instead of panicking on rv.Type() below
		b = append(b, typeUNDEF)

	case reflect.Slice, reflect.Array:
		// plain []uint8 is handled in encode(); everything else becomes an array
		b, err = e.encodeArray(b, rv, isRefNext, state, remainingDepth)

	case reflect.Map:
		b, err = e.encodeMap(b, rv, isRefNext, state, remainingDepth)

	case reflect.Struct:
		b, err = e.encodeStruct(b, rv, state, remainingDepth)

	case reflect.Ptr:
		b, err = e.encodePointer(b, rv, state, remainingDepth)

	case reflect.Bool:
		if rv.Bool() {
			b = append(b, typeTRUE)
		} else {
			b = append(b, typeFALSE)
		}

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		// encodeInt only applies the negative (NEG/ZIGZAG) forms when the kind
		// is exactly reflect.Int, so every signed kind is normalised to it
		b = e.encodeInt(b, reflect.Int, rv.Int())

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// rv.Int() panics for unsigned kinds; pass the bit pattern as int64,
		// the same way encode() does for uint64
		b = e.encodeInt(b, reflect.Uint, int64(rv.Uint()))

	case reflect.Float32:
		b = e.encodeFloat(b, float32(rv.Float()))

	case reflect.Float64:
		b = e.encodeDouble(b, rv.Float())

	case reflect.String:
		b = e.encodeString(b, rv.String(), isKeyOrClass, state.strTable)

	default:
		panic(fmt.Sprintf("no support for type '%s' (%s)", rk.String(), rv.Type()))
	}

	return b, err
}

// encodeArray appends arr, a slice or array accessed through reflection, as a
// Sereal ARRAY. In PerlCompat mode the array is preceded by a REFN tag unless
// the caller already wrote one (isRefNext). It returns ErrMaxRecursionLimit
// once the recursion budget is exhausted, and a nil slice on any error.
func (e *Encoder) encodeArray(by []byte, arr reflect.Value, isRefNext bool, state *marshalState, remainingDepth int) ([]byte, error) {
	depth := remainingDepth - 1
	if depth <= 0 {
		return nil, ErrMaxRecursionLimit
	}

	if !isRefNext && e.PerlCompat {
		by = append(by, typeREFN)
	}

	count := arr.Len()
	by = varint(append(by, typeARRAY), uint(count))

	for idx := 0; idx < count; idx++ {
		var err error
		if by, err = e.encode(by, arr.Index(idx), false, false, state, depth); err != nil {
			return nil, err
		}
	}

	return by, nil
}

// encodeMap appends m, a map accessed through reflection, as a Sereal HASH.
//
// In PerlCompat mode a REFN tag is written first unless the caller already
// emitted one (isRefNext), and every key is written as a string because Perl
// hash keys are always strings. Outside PerlCompat mode keys keep their Go
// type and are encoded with isKeyOrClass set. Pairs follow the order returned
// by MapKeys. It returns ErrMaxRecursionLimit once the recursion budget is
// exhausted.
func (e *Encoder) encodeMap(by []byte, m reflect.Value, isRefNext bool, state *marshalState, remainingDepth int) ([]byte, error) {
	remainingDepth--
	if remainingDepth <= 0 {
		return nil, ErrMaxRecursionLimit
	}

	if e.PerlCompat && !isRefNext {
		by = append(by, typeREFN)
	}

	keys := m.MapKeys()
	by = append(by, typeHASH)
	by = varint(by, uint(len(keys)))

	// perlKey stringifies a key for PerlCompat mode. reflect.Value.String()
	// alone is wrong for non-string kinds: it yields "<int Value>" and similar.
	perlKey := func(k reflect.Value) string {
		switch k.Kind() {
		case reflect.String:
			return k.String()
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			return strconv.FormatInt(k.Int(), 10)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			return strconv.FormatUint(k.Uint(), 10)
		default:
			return fmt.Sprint(k.Interface())
		}
	}

	var err error
	for _, k := range keys {
		if e.PerlCompat {
			by = e.encodeString(by, perlKey(k), true, state.strTable)
		} else if by, err = e.encode(by, k, true, false, state, remainingDepth); err != nil {
			return by, err
		}

		if by, err = e.encode(by, m.MapIndex(k), false, false, state, remainingDepth); err != nil {
			return by, err
		}
	}

	return by, nil
}

// encodeStruct appends st as a Sereal OBJECT: the struct's type name
// (deduplicated like other class names) followed by a HASH of its fields.
// Field names and indices come from e.tcache. In PerlCompat mode a REFN tag
// sits between the class name and the hash. It returns ErrMaxRecursionLimit
// once the recursion budget is exhausted, and a nil slice on any error.
func (e *Encoder) encodeStruct(by []byte, st reflect.Value, state *marshalState, remainingDepth int) ([]byte, error) {
	depth := remainingDepth - 1
	if depth <= 0 {
		return nil, ErrMaxRecursionLimit
	}

	fields := e.tcache.Get(st)
	className := st.Type().Name()

	by = append(by, typeOBJECT)
	by = e.encodeBytes(by, []byte(className), true, state.strTable)
	if e.PerlCompat {
		// a blessed Perl object must be a reference
		by = append(by, typeREFN)
	}

	by = varint(append(by, typeHASH), uint(len(fields)))

	for name, idx := range fields {
		by = e.encodeString(by, name, true, state.strTable)

		var err error
		if by, err = e.encode(by, st.Field(idx), false, false, state, depth); err != nil {
			return nil, err
		}
	}

	return by, nil
}

// encodePointer appends the pointer rv and tracks pointer identity in
// state.ptrTable, so a pointer reached more than once becomes a back-reference.
//
// Pointers to the Perl* helper structs (PerlRegexp, PerlUndef, PerlObject,
// PerlWeakRef) are encoded as the pointed-to value itself, without a REFN.
//
// Any other pointer is looked up under two identities:
//   - rvptr: the address the pointer holds;
//   - rvptr2: getPointer of the pointee, i.e. the storage of the map, slice,
//     pointer or string it points to (0 if there is none).
//
// If either identity has been seen, a REFP tag with the recorded offset is
// emitted, and trackFlag is set on the byte at that offset. Otherwise the
// function:
//   1. writes a REFN tag;
//   2. maps rvptr to that tag's offset before recursing, so a cycle back to
//      the same pointer resolves to a REFP;
//   3. encodes the pointee with isRefNext=true;
//   4. maps rvptr2 to the offset right after the REFN, where the pointee's
//      encoding starts.
//
// A nil pointer has rvptr == rvptr2 == 0, is never recorded, and is written as
// REFN followed by the encoding of an invalid reflect.Value (UNDEF).
func (e *Encoder) encodePointer(by []byte, rv reflect.Value, state *marshalState, remainingDepth int) ([]byte, error) {
	remainingDepth--
	if remainingDepth <= 0 {
		return nil, ErrMaxRecursionLimit
	}

	// Pointers to the Perl* wrapper types are transparent: encode the wrapped
	// value directly instead of emitting a reference to it.

	if rv.Elem().Kind() == reflect.Struct {
		switch rv.Elem().Interface().(type) {
		case PerlRegexp:
			return e.encode(by, rv.Elem(), false, false, state, remainingDepth)
		case PerlUndef:
			return e.encode(by, rv.Elem(), false, false, state, remainingDepth)
		case PerlObject:
			return e.encode(by, rv.Elem(), false, false, state, remainingDepth)
		case PerlWeakRef:
			return e.encode(by, rv.Elem(), false, false, state, remainingDepth)
		}
	}

	rvptr := rv.Pointer()             // address held by the pointer
	rvptr2 := getPointer(rv.Elem()) // identity of the pointee's own storage, or 0

	offs, ok := state.ptrTable[rvptr]

	if !ok && rvptr2 != 0 {
		offs, ok = state.ptrTable[rvptr2]
		if ok {
			// note: rvptr is not read again on the "seen" path below
			rvptr = rvptr2
		}
	}

	if ok { // seen this before
		by = append(by, typeREFP)
		by = varint(by, uint(offs))
		by[offs] |= trackFlag // original offset now tracked
	} else {
		lenbOrig := len(by)

		by = append(by, typeREFN)

		// record before recursing so self-references resolve to REFP
		if rvptr != 0 {
			state.ptrTable[rvptr] = lenbOrig
		}

		var err error
		by, err = e.encode(by, rv.Elem(), false, true, state, remainingDepth)
		if err != nil {
			return nil, err
		}

		if rvptr2 != 0 {
			// The thing this points to starts one byte after the REFN tag
			state.ptrTable[rvptr2] = lenbOrig + 1
		}
	}

	return by, nil
}

// varint appends n to by as an unsigned LEB128 varint: 7 bits per byte,
// least-significant group first, with the high bit set on every byte except
// the last.
func varint(by []byte, n uint) []uint8 {
	for ; n > 0x7f; n >>= 7 {
		by = append(by, 0x80|byte(n&0x7f))
	}
	return append(by, byte(n))
}

func getPointer(rv reflect.Value) uintptr {
	var rvptr uintptr

	switch rv.Kind() {
	case reflect.Map, reflect.Slice:
		rvptr = rv.Pointer()
	case reflect.Interface:
		// FIXME: still needed?
		return getPointer(rv.Elem())
	case reflect.Ptr:
		rvptr = rv.Pointer()
	case reflect.String:
		ps := (*reflect.StringHeader)(unsafe.Pointer(rv.UnsafeAddr()))
		rvptr = ps.Data
	}

	return rvptr
}

// concreteName returns the package-qualified name of value's type, formatted
// as "<import path>.<type name>". For predeclared types the package path is
// empty, giving e.g. ".int".
func concreteName(value reflect.Value) string {
	typ := value.Type()
	return typ.PkgPath() + "." + typ.Name()
}

// remainingRecursion turns a MaxRecursionDepth setting into the initial
// recursion budget. A positive maxDepth is used as-is; zero or a negative
// value means "no practical limit" and yields 2147483647 (the int32 maximum).
func remainingRecursion(maxDepth int) int {
	if maxDepth > 0 {
		return maxDepth
	}
	return 1<<31 - 1
}
