Files
Sigma-C2/server/task_handler.go

268 lines
6.5 KiB
Go
Raw Normal View History

2025-02-06 14:42:06 +01:00
package main
import (
"fmt"
"log"
"math/rand"
2025-02-06 14:42:06 +01:00
"net"
"sync"
"time"
2025-02-06 14:42:06 +01:00
)
2025-04-26 21:11:19 +02:00
type TaskHandler struct {
// Using sync.Map instead of a regular map with mutex
agentTasks sync.Map // Maps AgentID to a queue of tasks ([]Task)
}
var taskHandler = NewTaskHandler()
// Characters for generating random task IDs
var taskIDCharset = "0123456789"
var taskIDLength = 6
// Task is a single unit of work queued for an agent.
type Task struct {
	TaskID       string   // ID generated when the task is queued
	AgentID      string   // agent the task is queued for
	Type         string   // task type (QueueTask's taskType argument)
	Args         string   // task arguments (QueueTask's taskArgs argument)
	Payload      []byte   // optional binary payload (QueueTask's moduleData argument)
	OperatorConn net.Conn // not set by QueueTask; presumably filled in elsewhere — verify
	OperatorID   string   // ID of the operator that queued the task
	Dispatched   bool     // set once GetNextTask has handed the task to the agent
}
2025-04-26 21:11:19 +02:00
// Initialize the random number generator with a seed
func init() {
rand.New(rand.NewSource(time.Now().UnixNano()))
2025-02-06 14:42:06 +01:00
}
2025-04-26 21:11:19 +02:00
// generateTaskID creates a random 6-character alphanumeric task ID (lowercase)
func generateTaskID() string {
b := make([]byte, taskIDLength)
for i := range b {
b[i] = taskIDCharset[rand.Intn(len(taskIDCharset))]
}
return string(b)
}
2025-02-06 14:42:06 +01:00
// NewAgentHandler initializes a new TaskHandler instance.
func NewTaskHandler() *TaskHandler {
return &TaskHandler{}
2025-02-06 14:42:06 +01:00
}
// QueueTask queues a task for a specific agent.
func (th *TaskHandler) QueueTask(agentID string, operatorID string, taskType string, taskArgs string, moduleData []byte) string {
2025-02-06 14:42:06 +01:00
var logMessage string
var taskID string
2025-02-06 14:42:06 +01:00
// Check if agent exists
if _, ok := agents.Load(agentID); ok {
// Generate a unique task ID
taskID = generateTaskID()
2025-02-06 14:42:06 +01:00
task := Task{
TaskID: taskID,
2025-02-06 14:42:06 +01:00
AgentID: agentID,
Type: taskType,
Args: taskArgs,
Payload: moduleData,
OperatorID: operatorID,
Dispatched: false,
2025-02-06 14:42:06 +01:00
}
// Get the current task queue or create a new one
var tasks []Task
if existingTasks, ok := th.agentTasks.Load(agentID); ok {
tasks = existingTasks.([]Task)
}
// Append the new task and store back in the sync.Map
tasks = append(tasks, task)
th.agentTasks.Store(agentID, tasks)
logMessage = fmt.Sprintf("Queued task #%s for agent %s: %s", taskID, agentID, taskType)
2025-02-06 14:42:06 +01:00
} else {
logMessage = fmt.Sprintf("Agent not found: %s", agentID)
}
// Log logMessage
log.Print(logMessage)
2025-02-06 14:42:06 +01:00
return taskID // Return the generated task ID
2025-02-06 14:42:06 +01:00
}
// GetNextTask retrieves the next undispatched task for an agent, if any.
2025-02-06 14:42:06 +01:00
func (th *TaskHandler) GetNextTask(agentID string) *Task {
tasksInterface, exists := th.agentTasks.Load(agentID)
if !exists {
return nil
}
2025-02-06 14:42:06 +01:00
tasks := tasksInterface.([]Task)
if len(tasks) == 0 {
2025-02-06 14:42:06 +01:00
return nil
}
// Find the first undispatched task
for i := range tasks {
if !tasks[i].Dispatched {
// Mark the task as dispatched
tasks[i].Dispatched = true
// Update the tasks in storage with the modified task
th.agentTasks.Store(agentID, tasks)
// Log that the task is being executed
log.Printf("Agent %s fetched task %s: %s", agentID, tasks[i].TaskID, tasks[i].Type)
return &tasks[i]
}
}
// No undispatched tasks found
return nil
2025-02-06 14:42:06 +01:00
}
// ListTasks lists all queued tasks for a specific agent.
func (th *TaskHandler) ListTasks(agentID string) []Task {
tasksInterface, exists := th.agentTasks.Load(agentID)
if !exists {
return []Task{} // Return empty slice if no tasks exist
}
2025-02-06 14:42:06 +01:00
return tasksInterface.([]Task)
2025-02-06 14:42:06 +01:00
}
// ClearTasks clears all tasks for a specific agent.
func (th *TaskHandler) ClearTasks(agentID string) {
th.agentTasks.Delete(agentID)
2025-02-06 14:42:06 +01:00
// Log message
2025-02-06 14:42:06 +01:00
logMessage := fmt.Sprintf("Cleared tasks for agent %s", agentID)
log.Print(logMessage)
2025-02-06 14:42:06 +01:00
// Notify operators
SendMessageToAllOperators(logMessage)
}
// GetTaskByID finds a task by its ID for a given agent
func (th *TaskHandler) GetTaskByID(agentID string, taskID string) (*Task, bool) {
tasksInterface, exists := th.agentTasks.Load(agentID)
if !exists {
return nil, false
}
tasks := tasksInterface.([]Task)
for i, task := range tasks {
if task.TaskID == taskID {
return &tasks[i], true
}
}
return nil, false
}
// RemoveTaskByID removes a specific task by its ID
func (th *TaskHandler) RemoveTaskByID(agentID string, taskID string) bool {
tasksInterface, exists := th.agentTasks.Load(agentID)
if !exists {
return false
}
tasks := tasksInterface.([]Task)
for i, task := range tasks {
if task.TaskID == taskID {
// Remove the task by slicing
updatedTasks := append(tasks[:i], tasks[i+1:]...)
// If there are still tasks, update the stored slice
if len(updatedTasks) > 0 {
th.agentTasks.Store(agentID, updatedTasks)
} else {
// If no tasks remain, delete the entry
th.agentTasks.Delete(agentID)
}
log.Printf("Removed task %s for agent %s", taskID, agentID)
return true
}
}
return false
}
// GetOperatorConnByTaskID retrieves the operator connection associated with a task
func (th *TaskHandler) GetOperatorConnByTaskID(agentID string, taskID string) (net.Conn, bool) {
task, found := th.GetTaskByID(agentID, taskID)
if !found {
return nil, false
}
// Get the operator connection
operatorConn, exists := GetOperatorConn(task.OperatorID)
return operatorConn, exists
}
// GetOperatorIDByTaskID retrieves the operator ID associated with a task
func (th *TaskHandler) GetOperatorIDByTaskID(agentID string, taskID string) (string, bool) {
task, found := th.GetTaskByID(agentID, taskID)
if !found {
return "", false
}
return task.OperatorID, true
}
// MarkTaskComplete marks a task as complete and optionally removes it
func (th *TaskHandler) MarkTaskComplete(agentID string, taskID string, removeTask bool) bool {
tasksInterface, exists := th.agentTasks.Load(agentID)
if !exists {
return false
}
tasks := tasksInterface.([]Task)
for i, task := range tasks {
if task.TaskID == taskID {
if removeTask {
// Remove the task by slicing
updatedTasks := append(tasks[:i], tasks[i+1:]...)
// If there are still tasks, update the stored slice
if len(updatedTasks) > 0 {
th.agentTasks.Store(agentID, updatedTasks)
} else {
// If no tasks remain, delete the entry
th.agentTasks.Delete(agentID)
}
log.Printf("Task %s for agent %s marked as complete and removed", taskID, agentID)
} else {
// Just update the task status if we're keeping it
log.Printf("Task %s for agent %s marked as complete", taskID, agentID)
}
return true
}
}
return false
}
// GetTaskCount reports how many tasks are queued for agentID (total) and how
// many of those have not yet been handed out (undispatched). Both are zero
// when the agent has no queue.
func (th *TaskHandler) GetTaskCount(agentID string) (total int, undispatched int) {
	stored, ok := th.agentTasks.Load(agentID)
	if !ok {
		return 0, 0
	}
	queue := stored.([]Task)

	dispatched := 0
	for i := range queue {
		if queue[i].Dispatched {
			dispatched++
		}
	}
	return len(queue), len(queue) - dispatched
}