// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package exemplar

import (
	"context"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)

// staticTime is the fixed timestamp attached to the measurements offered in
// ReservoirTest: Sat Jan 01 2000 00:00:00 GMT+0000.
var staticTime = time.Unix(946_684_800, 0)

// factory returns a ReservoirProvider for reservoirs of the requested
// capacity, along with the capacity those reservoirs actually have. The tests
// in this file compare actualCap against the capacity they need and skip
// themselves when it does not fit.
type factory func(requestedCap int) (r ReservoirProvider, actualCap int)

// ReservoirTest returns a test function exercising the reservoirs created
// through f.
//
// Every subtest requests the capacity it needs from f and skips itself when
// the capacity f reports back does not satisfy that requirement. N selects
// the numeric type of the measurement values that are offered.
func ReservoirTest[N int64 | float64](f factory) func(*testing.T) {
	return func(t *testing.T) {
		t.Helper()

		baseCtx := t.Context()

		t.Run("CaptureSpanContext", func(t *testing.T) {
			t.Helper()

			provider, capacity := f(1)
			if capacity < 1 {
				t.Skip("skipping, reservoir capacity less than 1:", capacity)
			}
			res := provider(*attribute.EmptySet())

			traceID := trace.TraceID{0x01}
			spanID := trace.SpanID{0x01}
			spanCtx := trace.ContextWithSpanContext(
				baseCtx,
				trace.NewSpanContext(trace.SpanContextConfig{
					TraceID:    traceID,
					SpanID:     spanID,
					TraceFlags: trace.FlagsSampled,
				}),
			)

			val := NewValue(N(10))
			res.Offer(spanCtx, staticTime, val, nil)

			var got []Exemplar
			res.Collect(&got)

			// The IDs of the sampled span are expected on the exemplar.
			require.Len(t, got, 1, "number of collected exemplars")
			assert.Equal(t, Exemplar{
				Time:    staticTime,
				Value:   val,
				SpanID:  spanID[:],
				TraceID: traceID[:],
			}, got[0])
		})

		t.Run("FilterAttributes", func(t *testing.T) {
			t.Helper()

			provider, capacity := f(1)
			if capacity < 1 {
				t.Skip("skipping, reservoir capacity less than 1:", capacity)
			}
			res := provider(*attribute.EmptySet())

			val := NewValue(N(10))
			dropped := []attribute.KeyValue{attribute.Bool("admin", true)}
			res.Offer(baseCtx, staticTime, val, dropped)

			var got []Exemplar
			res.Collect(&got)

			// The offered attributes are expected back as FilteredAttributes.
			require.Len(t, got, 1, "number of collected exemplars")
			assert.Equal(t, Exemplar{
				FilteredAttributes: []attribute.KeyValue{attribute.Bool("admin", true)},
				Time:               staticTime,
				Value:              val,
			}, got[0])
		})

		t.Run("CollectLessThanN", func(t *testing.T) {
			t.Helper()

			provider, capacity := f(2)
			if capacity < 2 {
				t.Skip("skipping, reservoir capacity less than 2:", capacity)
			}
			res := provider(*attribute.EmptySet())

			// A single offer into a reservoir that can hold more.
			res.Offer(baseCtx, staticTime, NewValue(N(10)), nil)

			var got []Exemplar
			res.Collect(&got)
			// No empty exemplars are exported.
			require.Len(t, got, 1, "number of collected exemplars")
		})

		t.Run("MultipleOffers", func(t *testing.T) {
			t.Helper()

			provider, capacity := f(3)
			if capacity < 1 {
				t.Skip("skipping, reservoir capacity less than 1:", capacity)
			}
			res := provider(*attribute.EmptySet())

			// overfill offers one measurement more than the reported capacity.
			overfill := func() {
				for i := range capacity + 1 {
					res.Offer(baseCtx, staticTime, NewValue(N(i)), nil)
				}
			}

			overfill()
			var got []Exemplar
			res.Collect(&got)
			assert.Len(t, got, capacity, "multiple offers did not fill reservoir")

			// Ensure the collect reset also resets any counting state.
			overfill()
			got = got[:0]
			res.Collect(&got)
			assert.Len(t, got, capacity, "internal count state not reset")
		})

		t.Run("DropAll", func(t *testing.T) {
			t.Helper()

			provider, capacity := f(0)
			if capacity > 0 {
				t.Skip("skipping, reservoir capacity greater than 0:", capacity)
			}
			res := provider(*attribute.EmptySet())

			res.Offer(t.Context(), staticTime, NewValue(N(10)), nil)

			// Start with a non-empty destination: Collect should reset it to empty.
			got := make([]Exemplar, 1)
			res.Collect(&got)
			assert.Empty(t, got, "no exemplars should be collected")
		})

		t.Run("Negative reservoir capacity drops all", func(t *testing.T) {
			t.Helper()

			provider, capacity := f(-1)
			if capacity > 0 {
				t.Skip("skipping, reservoir capacity greater than 0:", capacity)
			}
			assert.Zero(t, capacity)
			res := provider(*attribute.EmptySet())

			res.Offer(t.Context(), staticTime, NewValue(N(10)), nil)

			// Start with a non-empty destination: Collect should reset it to empty.
			got := make([]Exemplar, 1)
			res.Collect(&got)
			assert.Empty(t, got, "no exemplars should be collected")
		})
	}
}

// reservoirConcurrentSafeTest returns a test function that calls Offer and
// Collect on a single reservoir from several goroutines at once. After all
// goroutines finish, it collects again, requires a non-empty result, and
// validates each collected exemplar.
//
// The test is skipped when f reports a capacity below 1.
func reservoirConcurrentSafeTest[N int64 | float64](f factory) func(*testing.T) {
	return func(t *testing.T) {
		t.Helper()

		provider, capacity := f(1)
		if capacity < 1 {
			t.Skip("skipping, reservoir capacity less than 1:", capacity)
		}
		res := provider(*attribute.EmptySet())

		const offerers = 2

		// The extra slot accounts for the collecting goroutine started below.
		var wg sync.WaitGroup
		wg.Add(offerers + 1)

		// Offers run concurrently with each other and with the Collect below.
		for i := range offerers {
			go func() {
				defer wg.Done()
				// Indices start at 1: validateExemplar reports a zero value
				// as an empty exemplar.
				ctx, ts, val, attrs := generateOfferInputs[N](i + 1)
				res.Offer(ctx, ts, val, attrs)
			}()
		}

		// A Collect that runs alongside the offers; its result is discarded.
		go func() {
			defer wg.Done()
			var scratch []Exemplar
			res.Collect(&scratch)
		}()

		wg.Wait()

		// With all goroutines finished, inspect what the reservoir holds.
		var final []Exemplar
		res.Collect(&final)
		assert.NotEmpty(t, final)
		for idx := range final {
			validateExemplar[N](t, final[idx])
		}
	}
}

// generateOfferInputs derives a full set of Offer arguments from i: a context
// carrying a sampled span context, a timestamp, a value of numeric type N,
// and a single attribute keyed "i".
//
// Every returned piece is a function of i only, so callers can regenerate
// the same inputs later from the index alone.
func generateOfferInputs[N int64 | float64](
	i int,
) (context.Context, time.Time, Value, []attribute.KeyValue) {
	// Only the low byte of i goes into the first byte of each ID; the
	// remaining ID bytes stay zero.
	seed := byte(i)
	spanCtx := trace.NewSpanContext(trace.SpanContextConfig{
		TraceID:    trace.TraceID{seed},
		SpanID:     trace.SpanID{seed},
		TraceFlags: trace.FlagsSampled,
	})

	// i is used for both the seconds and the nanoseconds of the timestamp.
	tick := int64(i)

	return trace.ContextWithSpanContext(context.Background(), spanCtx),
		time.Unix(tick, tick),
		NewValue(N(i)),
		[]attribute.KeyValue{attribute.Int("i", i)}
}

// validateExemplar verifies that e carries the trace ID, span ID, timestamp,
// and attributes that generateOfferInputs yields for the index stored in e's
// value.
//
// The test fails immediately when the value has an unexpected type, or when
// it converts to index 0, which is reported as an empty exemplar.
func validateExemplar[N int64 | float64](t *testing.T, e Exemplar) {
	t.Helper()

	// Recover the index that produced this exemplar from its value.
	var idx int
	switch typ := e.Value.Type(); typ {
	case Int64ValueType:
		idx = int(e.Value.Int64())
	case Float64ValueType:
		idx = int(e.Value.Float64())
	default:
		t.Fatalf("unexpected value type: %v", typ)
	}
	if idx == 0 {
		t.Fatal("empty exemplar")
	}

	// Regenerate the inputs for idx and compare them field by field.
	wantCtx, wantTime, _, wantAttrs := generateOfferInputs[N](idx)
	wantSC := trace.SpanContextFromContext(wantCtx)
	wantTraceID, wantSpanID := wantSC.TraceID(), wantSC.SpanID()

	assert.Equal(t, wantTraceID[:], e.TraceID)
	assert.Equal(t, wantSpanID[:], e.SpanID)
	assert.Equal(t, wantTime, e.Time)
	assert.Equal(t, wantAttrs, e.FilteredAttributes)
}
