You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
233 lines
8.5 KiB
233 lines
8.5 KiB
from myutils.ConfigManager import myCongif
|
|
from mycode.TaskObject import TaskObject
|
|
from mycode.DBManager import app_DBM
|
|
from myutils.PickleManager import g_PKM
|
|
|
|
import threading
|
|
|
|
class TaskManager:
|
|
def __init__(self):
|
|
self.tasks = {} # 执行中的任务,test_id为key
|
|
self.num_threads = myCongif.get_data("Task_max_threads")
|
|
#获取系统信息 -- 用户可修改的都放在DB中,不修改的放config
|
|
data = app_DBM.get_system_info()
|
|
self.local_ip = data[0]
|
|
self.version = data[1]
|
|
self.tasks_lock = threading.Lock() #加个线程锁?不使用1,quart要使用的异步锁,2.限制只允许一个用户登录,3.遍历到删除的问题不大
|
|
self.web_cur_task = 0 #web端当前显示的
|
|
|
|
|
|
#判断目标是不是在当前执行任务中,---没加锁,最多跟删除有冲突,问题应该不大
|
|
def is_target_in_tasks(self,task_target):
|
|
for task in self.tasks.values():
|
|
if task_target == task.target:
|
|
return True
|
|
return False
|
|
|
|
#程序启动后,加载未完成的测试任务
|
|
def load_tasks(self):
|
|
'''程序启动时,加载未执行完成的任务'''
|
|
datas = app_DBM.get_run_tasks()
|
|
for data in datas:
|
|
task_id = data[0]
|
|
task_target = data[1]
|
|
task_status = data[2]
|
|
work_type = data[3]
|
|
cookie_info = data[4]
|
|
llm_type = data[5]
|
|
# 创建任务对象
|
|
task = TaskObject(task_target, cookie_info, work_type, llm_type, self.num_threads, self.local_ip,self)
|
|
#读取attact_tree
|
|
attack_tree = g_PKM.ReadData(str(task_id))
|
|
#开始任务 ---会根据task_status来判断是否需要启动工作线程
|
|
task.start_task(task_id,task_status,attack_tree)
|
|
# 保留task对象
|
|
self.tasks[task_id] = task
|
|
|
|
#新建测试任务
|
|
def create_task(self, test_target,cookie_info,llm_type,work_type):
|
|
"""创建新任务--create和load复用?--
|
|
1.创建和初始化task_object;
|
|
2.创建task_id
|
|
3.启动工作线程
|
|
return T/F
|
|
"""
|
|
if self.is_target_in_tasks(test_target):
|
|
raise ValueError(f"Task {test_target} already exists")
|
|
#创建任务对象
|
|
task = TaskObject(test_target,cookie_info,work_type,llm_type,self.num_threads,self.local_ip,self)
|
|
#获取task_id -- test_target,cookie_info,work_type,llm_type 入数据库
|
|
task_id = app_DBM.start_task(test_target,cookie_info,work_type,llm_type)
|
|
if task_id >0:
|
|
#创建后启动工作--同时赋值task_id
|
|
task.start_task(task_id)
|
|
#保留task对象
|
|
self.tasks[task_id] = task
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
def over_task(self,task_id):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
task.brun = False
|
|
#修改数据库数据
|
|
bsuccess = app_DBM.over_task(task_id)
|
|
if bsuccess:
|
|
del self.tasks[task_id] #删除缓存
|
|
return bsuccess,""
|
|
else:
|
|
return False,"没有找到对应的任务"
|
|
|
|
def del_task(self,task_id):
|
|
if g_PKM.DelData(str(task_id)):
|
|
bsuccess = app_DBM.del_task(task_id)
|
|
return bsuccess,""
|
|
else:
|
|
return False,"删除对应文件失败"
|
|
|
|
#控制task启停----线程不停
|
|
def control_taks(self,task_id):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
if task.task_status == 0: # 0-暂停,1-执行中,2-已完成
|
|
task.task_status = 1
|
|
elif task.task_status == 1:
|
|
task.task_status = 0
|
|
else:
|
|
return False,"当前任务状态不允许修改,请联系管理员!",task.task_status
|
|
else:
|
|
return False,"没有找到对应的任务",None
|
|
return True,"",task.task_status
|
|
|
|
# 获取任务list
|
|
def get_task_list(self):
|
|
tasks = []
|
|
for task in self.tasks.values():
|
|
one_json = {"taskID": task.task_id, "testTarget": task.target, "taskStatus": task.task_status, "safeRank": task.safe_rank,
|
|
"workType": task.work_type}
|
|
tasks.append(one_json)
|
|
rejson = {"tasks": tasks}
|
|
return rejson
|
|
|
|
#获取节点树
|
|
def get_node_tree(self,task_id):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
self.web_cur_task = task_id
|
|
tree_dict = task.attack_tree.get_node_dict()
|
|
return tree_dict
|
|
return None
|
|
|
|
#获取历史节点树数据
|
|
def get_his_node_tree(self,task_id):
|
|
attack_tree = g_PKM.ReadData(str(task_id))
|
|
if attack_tree:
|
|
tree_dict = attack_tree.get_node_dict()
|
|
return tree_dict
|
|
return None
|
|
|
|
#修改任务的工作模式,只有在暂停状态才能修改
|
|
def update_task_work_type(self,task_id,new_work_type):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
if task.task_status == 0:
|
|
task.work_type = new_work_type
|
|
return True
|
|
return False
|
|
|
|
#控制节点的工作状态
|
|
def node_bwork_control(self,task_id,node_path):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
bsuccess,new_bwork = task.attack_tree.update_node_bwork(node_path)
|
|
if bsuccess:
|
|
pass #是否要更新IO数据?----待验证是否有只修改部分数据的方案
|
|
return bsuccess,new_bwork
|
|
return False,False
|
|
|
|
#节点单步--只允許web端调用
|
|
async def node_one_step(self,task_id,node_path):
|
|
task = self.tasks[task_id]
|
|
node = task.attack_tree.find_node_by_nodepath(node_path)
|
|
#web端触发的任务,需要判断当前的执行状态
|
|
bsuccess,error = await task.put_one_node(node)
|
|
return bsuccess,error
|
|
|
|
#任务单点--只允许web端调用
|
|
async def task_one_step(self,task_id):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
bsuccess,error = await task.put_one_task()
|
|
return bsuccess,error
|
|
else:
|
|
return False,"task_id值存在问题!"
|
|
|
|
#获取节点待执行任务
|
|
def get_task_node_todo_instr(self,task_id,nodepath):
|
|
todoinstr = []
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
node = task.attack_tree.find_node_by_nodepath(nodepath)
|
|
if node:
|
|
todoinstr = node.get_instr_user()
|
|
return todoinstr
|
|
|
|
#获取节点的MSG信息
|
|
def get_task_node_MSG(self,task_id,nodepath):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
node = task.attack_tree.find_node_by_nodepath(nodepath)
|
|
if node:
|
|
tmpMsg = node.get_res_user()
|
|
if tmpMsg:
|
|
return node.messages,tmpMsg[0] #待提交消息正常应该只有一条
|
|
else:
|
|
return node.messages,{}
|
|
return [],{}
|
|
|
|
def update_node_MSG(self,task_id,nodepath,newtype,newcontent):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
node = task.attack_tree.find_node_by_nodepath(nodepath)
|
|
if node:
|
|
work_status = node.get_work_status()
|
|
if work_status == 0 or work_status == 3:
|
|
bsuccess,error = node.updatemsg(newtype,newcontent,0) #取的第一条,也就修改第一条
|
|
return bsuccess,error
|
|
else:
|
|
return False,"当前节点的工作状态不允许修改MSG!"
|
|
return False,"找不到对应节点!"
|
|
|
|
def del_node_instr(self,task_id,nodepath,instr):
|
|
task = self.tasks[task_id]
|
|
if task:
|
|
node = task.attack_tree.find_node_by_nodepath(nodepath)
|
|
if node:
|
|
return node.del_instr(instr)
|
|
return False,"找不到对应节点!"
|
|
|
|
def get_his_tasks(self,target_name,safe_rank,llm_type,start_time,end_time):
|
|
tasks = app_DBM.get_his_tasks(target_name,safe_rank,llm_type,start_time,end_time)
|
|
return tasks
|
|
|
|
|
|
#------------以下函数还未验证处理-----------
|
|
|
|
def start_task(self, task_id):
|
|
"""启动指定任务"""
|
|
task = self.tasks.get(task_id)
|
|
if task:
|
|
task.start(self.num_threads)
|
|
else:
|
|
print(f"Task {task_id} not found")
|
|
|
|
def stop_task(self, task_id):
|
|
"""停止指定任务"""
|
|
task = self.tasks.get(task_id)
|
|
if task:
|
|
task.stop()
|
|
else:
|
|
print(f"Task {task_id} not found")
|
|
|
|
g_TaskM = TaskManager() #单一实例
|