package main import ( "fmt" "log" "math/rand" "net" "sync" "time" ) type TaskHandler struct { // Using sync.Map instead of a regular map with mutex agentTasks sync.Map // Maps AgentID to a queue of tasks ([]Task) } var taskHandler = NewTaskHandler() // Characters for generating random task IDs var taskIDCharset = "0123456789" var taskIDLength = 6 type Task struct { TaskID string AgentID string Type string Args string Arg1 string Arg2 string Arg3 string Payload []byte OperatorConn net.Conn OperatorID string Dispatched bool // Tracks whether the task has been dispatched } // Initialize the random number generator with a seed func init() { rand.New(rand.NewSource(time.Now().UnixNano())) } // generateTaskID creates a random 6-character alphanumeric task ID (lowercase) func generateTaskID() string { b := make([]byte, taskIDLength) for i := range b { b[i] = taskIDCharset[rand.Intn(len(taskIDCharset))] } return string(b) } // NewAgentHandler initializes a new TaskHandler instance. func NewTaskHandler() *TaskHandler { return &TaskHandler{} } // QueueTask queues a task for a specific agent. func (th *TaskHandler) QueueTask(agentID string, operatorID string, taskType string, taskArgs string, moduleData []byte) string { var logMessage string var taskID string // Check if agent exists if _, ok := agents.Load(agentID); ok { // Generate a unique task ID taskID = generateTaskID() task := Task{ TaskID: taskID, AgentID: agentID, Type: taskType, Args: taskArgs, Payload: moduleData, OperatorID: operatorID, Dispatched: false, } // Get the current task queue or create a new one var tasks []Task if existingTasks, ok := th.agentTasks.Load(agentID); ok { tasks = existingTasks.([]Task) } // Append the new task and store back in the sync.Map tasks = append(tasks, task) th.agentTasks.Store(agentID, tasks) logMessage = fmt.Sprintf("Queued task #%s for agent %s: %s", taskID, agentID, taskType) } else { logMessage = fmt.Sprintf("Agent not found: %s", agentID) } // Log logMessage log.Print(logMessage) return taskID // Return the generated task ID } // GetNextTask retrieves the next undispatched task for an agent, if any. func (th *TaskHandler) GetNextTask(agentID string) *Task { tasksInterface, exists := th.agentTasks.Load(agentID) if !exists { return nil } tasks := tasksInterface.([]Task) if len(tasks) == 0 { return nil } // Find the first undispatched task for i := range tasks { if !tasks[i].Dispatched { // Mark the task as dispatched tasks[i].Dispatched = true // Update the tasks in storage with the modified task th.agentTasks.Store(agentID, tasks) // Log that the task is being executed log.Printf("Agent %s fetched task: %s: %s", agentID, tasks[i].TaskID, tasks[i].Type) return &tasks[i] } } // No undispatched tasks found return nil } // ListTasks lists all queued tasks for a specific agent. func (th *TaskHandler) ListTasks(agentID string) []Task { tasksInterface, exists := th.agentTasks.Load(agentID) if !exists { return []Task{} // Return empty slice if no tasks exist } return tasksInterface.([]Task) } // ClearTasks clears all tasks for a specific agent. func (th *TaskHandler) ClearTasks(agentID string) { th.agentTasks.Delete(agentID) // Log message logMessage := fmt.Sprintf("Cleared tasks for agent %s", agentID) log.Print(logMessage) // Notify operators // SendMessageToAllOperators(logMessage) } // GetTaskByID finds a task by its ID for a given agent func (th *TaskHandler) GetTaskByID(agentID string, taskID string) (*Task, bool) { tasksInterface, exists := th.agentTasks.Load(agentID) if !exists { return nil, false } tasks := tasksInterface.([]Task) for i, task := range tasks { if task.TaskID == taskID { return &tasks[i], true } } return nil, false } // RemoveTaskByID removes a specific task by its ID func (th *TaskHandler) RemoveTaskByID(agentID string, taskID string) bool { tasksInterface, exists := th.agentTasks.Load(agentID) if !exists { return false } tasks := tasksInterface.([]Task) for i, task := range tasks { if task.TaskID == taskID { // Remove the task by slicing updatedTasks := append(tasks[:i], tasks[i+1:]...) // If there are still tasks, update the stored slice if len(updatedTasks) > 0 { th.agentTasks.Store(agentID, updatedTasks) } else { // If no tasks remain, delete the entry th.agentTasks.Delete(agentID) } log.Printf("Removed task %s for agent %s", taskID, agentID) return true } } return false } // GetOperatorConnByTaskID retrieves the operator connection associated with a task func (th *TaskHandler) GetOperatorConnByTaskID(agentID string, taskID string) (net.Conn, bool) { task, found := th.GetTaskByID(agentID, taskID) if !found { return nil, false } // Get the operator connection operatorConn, exists := GetOperatorConn(task.OperatorID) return operatorConn, exists } // GetOperatorIDByTaskID retrieves the operator ID associated with a task func (th *TaskHandler) GetOperatorIDByTaskID(agentID string, taskID string) (string, bool) { task, found := th.GetTaskByID(agentID, taskID) if !found { return "", false } return task.OperatorID, true } // MarkTaskComplete marks a task as complete and optionally removes it func (th *TaskHandler) MarkTaskComplete(agentID string, taskID string, removeTask bool) bool { tasksInterface, exists := th.agentTasks.Load(agentID) if !exists { return false } tasks := tasksInterface.([]Task) for i, task := range tasks { if task.TaskID == taskID { if removeTask { // Remove the task by slicing updatedTasks := append(tasks[:i], tasks[i+1:]...) // If there are still tasks, update the stored slice if len(updatedTasks) > 0 { th.agentTasks.Store(agentID, updatedTasks) } else { // If no tasks remain, delete the entry th.agentTasks.Delete(agentID) } log.Printf("Task %s for agent %s marked as complete and removed", taskID, agentID) } else { // Just update the task status if we're keeping it log.Printf("Task %s for agent %s marked as complete", taskID, agentID) } return true } } return false } // GetTaskCount returns the count of all tasks and undispatched tasks for an agent func (th *TaskHandler) GetTaskCount(agentID string) (total int, undispatched int) { tasksInterface, exists := th.agentTasks.Load(agentID) if !exists { return 0, 0 } tasks := tasksInterface.([]Task) total = len(tasks) for _, task := range tasks { if !task.Dispatched { undispatched++ } } return total, undispatched }