// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

// This file implements Protected Resource Metadata.
// See https://www.rfc-editor.org/rfc/rfc9728.html.

//go:build mcp_go_client_oauth

package oauthex

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"path"
	"strings"
	"unicode"

	"github.com/modelcontextprotocol/go-sdk/internal/util"
)

// defaultProtectedResourceMetadataURI is the well-known path for protected
// resource metadata. GetProtectedResourceMetadataFromID inserts it between the
// host and the path of a resource identifier (RFC 9728, section 3).
const defaultProtectedResourceMetadataURI = "/.well-known/oauth-protected-resource"

// GetProtectedResourceMetadataFromID fetches the protected resource metadata
// of the resource server identified by resourceID.
//
// resourceID is expected to be an HTTPS URL: a host (optionally with a port)
// and possibly a path, for example
//
//	https://example.com/server
//
// As RFC 9728 §3 describes, the well-known suffix is spliced in between the
// host and the path, so the example above is fetched from
//
//	https://example.com/.well-known/oauth-protected-resource/server
//
// The request is made with c (or the default client if c is nil), and the
// resource field of the returned metadata must equal resourceID.
func GetProtectedResourceMetadataFromID(ctx context.Context, resourceID string, c *http.Client) (_ *ProtectedResourceMetadata, err error) {
	defer util.Wrapf(&err, "GetProtectedResourceMetadataFromID(%q)", resourceID)

	metadataURL, err := url.Parse(resourceID)
	if err != nil {
		return nil, err
	}
	// path.Join cleans its result, so a trailing slash on the resource path
	// (including a bare "/") is not carried into the metadata URL.
	metadataURL.Path = path.Join(defaultProtectedResourceMetadataURI, metadataURL.Path)
	return getPRM(ctx, metadataURL.String(), c, resourceID)
}

// GetProtectedResourceMetadataFromHeader fetches protected resource metadata
// from the URL advertised by the resource_metadata parameter of the
// WWW-Authenticate values in header. It uses c, or the default client if c
// is nil.
//
// serverURL is the URL the client used for its original request to the
// resource server. As RFC 9728 §3.3 requires, the resource field of the
// fetched metadata must match it.
//
// If header has no WWW-Authenticate values, or no challenge names a metadata
// URL, the result is nil, nil.
func GetProtectedResourceMetadataFromHeader(ctx context.Context, serverURL string, header http.Header, c *http.Client) (_ *ProtectedResourceMetadata, err error) {
	defer util.Wrapf(&err, "GetProtectedResourceMetadataFromHeader")

	values := header.Values("WWW-Authenticate")
	if len(values) == 0 {
		return nil, nil
	}
	challenges, err := ParseWWWAuthenticate(values)
	if err != nil {
		return nil, err
	}
	if metadataURL := ResourceMetadataURL(challenges); metadataURL != "" {
		return getPRM(ctx, metadataURL, c, serverURL)
	}
	return nil, nil
}

// getPRM retrieves the protected resource metadata document at purl and
// validates it before returning it.
//
// purl must begin with "https://", compared case-insensitively. The
// document's resource field must equal wantResource exactly (RFC 9728,
// section 3.3). Every listed authorization server URL must also pass
// checkURLScheme, which guards against XSS through malicious URLs (see #526).
func getPRM(ctx context.Context, purl string, c *http.Client, wantResource string) (*ProtectedResourceMetadata, error) {
	if upper := strings.ToUpper(purl); !strings.HasPrefix(upper, "HTTPS://") {
		return nil, fmt.Errorf("resource URL %q does not use HTTPS", purl)
	}

	// NOTE(review): 1<<20 (1 MiB) is passed to getJSON; presumably a cap on
	// the response body size. Confirm against getJSON.
	meta, err := getJSON[ProtectedResourceMetadata](ctx, c, purl, 1<<20)
	if err != nil {
		return nil, err
	}

	if got := meta.Resource; got != wantResource {
		return nil, fmt.Errorf("got metadata resource %q, want %q", got, wantResource)
	}

	for _, server := range meta.AuthorizationServers {
		if err := checkURLScheme(server); err != nil {
			return nil, err
		}
	}
	return meta, nil
}

// challenge is one authentication challenge taken from a WWW-Authenticate
// header. RFC 9110, Section 11.6.1 defines a challenge as an auth-scheme,
// optionally followed by parameters.
type challenge struct {
	// GENERATED BY GEMINI 2.5.

	// Scheme names the authentication scheme, such as "Bearer" or "Basic".
	// Schemes compare case-insensitively, so the parser stores them lower-cased.
	Scheme string

	// Params holds the challenge's auth-params, keyed by parameter name.
	// Names compare case-insensitively, so the parser stores them lower-cased.
	Params map[string]string
}

// ResourceMetadataURL scans cs in order and returns the first non-empty
// "resource_metadata" parameter it finds. It returns "" if no challenge
// carries one.
func ResourceMetadataURL(cs []challenge) string {
	for i := range cs {
		// Indexing a nil Params map is safe and yields "".
		if metadataURL := cs[i].Params["resource_metadata"]; metadataURL != "" {
			return metadataURL
		}
	}
	return ""
}

// ParseWWWAuthenticate parses the values of one or more WWW-Authenticate
// headers. The header format is defined in RFC 9110, Section 11.6.1, and
// each value may contain several comma-separated challenges.
//
// All challenges, across all values, are returned in the order they appear.
// Blank challenges are skipped. An error is returned if any value or
// challenge is malformed.
func ParseWWWAuthenticate(headers []string) ([]challenge, error) {
	// GENERATED BY GEMINI 2.5 (human-tweaked)
	var parsed []challenge
	for _, value := range headers {
		pieces, err := splitChallenges(value)
		if err != nil {
			return nil, err
		}
		for _, piece := range pieces {
			if blank := strings.TrimSpace(piece) == ""; blank {
				continue
			}
			ch, perr := parseSingleChallenge(piece)
			if perr != nil {
				return nil, fmt.Errorf("failed to parse challenge %q: %w", piece, perr)
			}
			parsed = append(parsed, ch)
		}
	}
	return parsed, nil
}

// splitChallenges splits a single WWW-Authenticate header value into the raw
// text of each challenge it contains.
//
// RFC 9110 §11.6.1 uses commas both between the auth-params of one challenge
// and between challenges. A comma outside a quoted string therefore ends a
// challenge only when the text after it does not look like an auth-param,
// that is, like a single token followed by '='. Whitespace before the '='
// (`key = value`, allowed by the grammar) is accepted.
//
// Inside a quoted string, a backslash escapes the next character (quoted-pair,
// RFC 9110 §5.6.4). So `\"` does not end the string, and `\\` is a literal
// backslash that does not escape a following quote.
//
// The returned parts are neither trimmed nor validated and may be blank. The
// only error is a value that begins with '"', because a challenge must start
// with an auth-scheme token.
func splitChallenges(header string) ([]string, error) {
	// GENERATED BY GEMINI 2.5, human-tweaked.
	if strings.HasPrefix(header, `"`) {
		return nil, errors.New(`challenge begins with '"'`)
	}

	var challenges []string
	inQuotes := false
	escaped := false // the previous byte was a backslash inside a quoted string
	start := 0
	// Scanning bytes rather than runes is safe: '"', '\\' and ',' are ASCII
	// and never occur inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(header); i++ {
		switch ch := header[i]; {
		case escaped:
			escaped = false
		case inQuotes && ch == '\\':
			escaped = true
		case ch == '"':
			inQuotes = !inQuotes
		case ch == ',' && !inQuotes:
			// A new challenge begins with an auth-scheme, never with
			// `token=`, so peek at what follows the comma.
			lookahead := strings.TrimSpace(header[i+1:])
			isParam := false
			if eq := strings.IndexByte(lookahead, '='); eq > 0 {
				key := strings.TrimRightFunc(lookahead[:eq], unicode.IsSpace)
				isParam = strings.IndexFunc(key, unicode.IsSpace) == -1
			}
			if !isParam {
				challenges = append(challenges, header[start:i])
				start = i + 1
			}
		}
	}
	// Add the last (or only) challenge.
	challenges = append(challenges, header[start:])
	return challenges, nil
}

// parseSingleChallenge parses the text of exactly one challenge:
//
//	challenge = auth-scheme [ 1*SP ( token68 / #auth-param ) ]
//
// s is trimmed first. The scheme is everything before the first space. The
// scheme and the parameter names are lower-cased, because RFC 9110 treats
// both as case-insensitive. A challenge without parameters gets a nil Params
// map. The token68 form is not supported.
//
// A parameter value is either a token, which ends at the next comma, or a
// quoted string, in which a backslash escapes the following character. An
// empty quoted string ("") is a valid value; an empty token is not.
//
// An error is returned for any of the following:
//   - an empty challenge;
//   - a parameter without '=' or with a blank name;
//   - an empty token value;
//   - an unterminated quoted string;
//   - anything other than a comma between parameters.
func parseSingleChallenge(s string) (challenge, error) {
	// GENERATED BY GEMINI 2.5, human-tweaked.
	s = strings.TrimSpace(s)
	if s == "" {
		return challenge{}, errors.New("empty challenge string")
	}

	scheme, paramsStr, found := strings.Cut(s, " ")
	// Per RFC 9110, the scheme is case-insensitive.
	scheme = strings.ToLower(scheme)
	if !found {
		return challenge{Scheme: scheme}, nil
	}

	params := make(map[string]string)
	for paramsStr != "" {
		// The key runs up to the next '=' and must not be blank.
		keyEnd := strings.IndexByte(paramsStr, '=')
		if keyEnd < 0 || strings.TrimSpace(paramsStr[:keyEnd]) == "" {
			return challenge{}, fmt.Errorf("malformed auth parameter: expected key=value, but got %q", paramsStr)
		}
		key := strings.TrimSpace(paramsStr[:keyEnd])

		// Move past the key and the '='.
		paramsStr = strings.TrimSpace(paramsStr[keyEnd+1:])

		var value string
		quoted := strings.HasPrefix(paramsStr, `"`)
		if quoted {
			paramsStr = paramsStr[1:] // Consume the opening quote.
			var valBuilder strings.Builder
			i := 0
			for ; i < len(paramsStr); i++ {
				if paramsStr[i] == '\\' && i+1 < len(paramsStr) {
					// quoted-pair: keep the escaped byte, drop the backslash.
					valBuilder.WriteByte(paramsStr[i+1])
					i++
				} else if paramsStr[i] == '"' {
					break // End of the quoted string.
				} else {
					valBuilder.WriteByte(paramsStr[i])
				}
			}
			if i == len(paramsStr) {
				return challenge{}, errors.New("unterminated quoted string in auth parameter")
			}
			value = valBuilder.String()
			// Move past the value and the closing quote.
			paramsStr = strings.TrimSpace(paramsStr[i+1:])
		} else {
			// A token value ends at the next comma or the end of the string.
			if commaPos := strings.IndexByte(paramsStr, ','); commaPos == -1 {
				value = paramsStr
				paramsStr = ""
			} else {
				value = strings.TrimSpace(paramsStr[:commaPos])
				paramsStr = strings.TrimSpace(paramsStr[commaPos:]) // Keep the comma for the check below.
			}
		}
		// A token must be non-empty (token = 1*tchar), but an empty
		// quoted-string ("") is a legal parameter value.
		if value == "" && !quoted {
			return challenge{}, fmt.Errorf("no value for auth param %q", key)
		}

		// Per RFC 9110, parameter keys are case-insensitive.
		params[strings.ToLower(key)] = value

		// If there is a comma, consume it and continue to the next parameter.
		if strings.HasPrefix(paramsStr, ",") {
			paramsStr = strings.TrimSpace(paramsStr[1:])
		} else if paramsStr != "" {
			return challenge{}, fmt.Errorf("malformed auth parameter: expected comma after value, but got %q", paramsStr)
		}
	}

	return challenge{Scheme: scheme, Params: params}, nil
}
