import queue import copy import re import threading #渗透测试树结构维护类 class AttackTree: def __init__(self,root_node): #针对根节点处理 self.root = root_node self.root.path = f"目标系统->{root_node.name}" def set_root(self,root_node): self.root = root_node def traverse_bfs(self): """广度优先遍历""" if not self.root: return [] queue = [self.root] result = [] while queue: current = queue.pop(0) result.append(current) queue.extend(current.children) return result def traverse_dfs(self, node=None, result=None): """深度优先遍历(前序遍历)""" if result is None: result = [] if node is None: node = self.root if not node: return [] result.append(node) for child in node.children: self.traverse_dfs(child, result) return result #生成节点树字典数据 def node_to_dict(self,node): return { "node_name":node.name, "node_path":node.path, "node_status":node.status, "node_bwork":node.bwork, "node_vultype":node.vul_type, "node_vulgrade":node.vul_grade, "node_vulinfo":node.vul_info, "node_workstatus":node.get_work_status(), "children":[self.node_to_dict(child) for child in node.children] if node.children else [] } #树简化列表,用户传输到前端 def get_node_dict(self): node_dict = self.node_to_dict(self.root) #递归生成 return node_dict def find_node_by_name(self, name): """根据名称查找节点(广度优先)""" nodes = self.traverse_bfs() for node in nodes: if node.name == name: return node return None #返回需要插入测试指令的节点 def find_node_by_nodepath_parent(self,node_path,cur_node,iadd_node,commands): node_names = node_path.split('->') node_name = node_names[-1] #返回指令对应的节点名称 if node_name == cur_node.name:#当前节点 if iadd_node == 1 and len(cur_node.children)==1: #如果在当前节点下添加了一个子节点,且子节点没有添加指令,就算指令的节点路径是当前节点,也把该测试指令添加给新增的子节点 bfind = False for comd in commands: if cur_node.children[0].name in comd: bfind = True break if not bfind: #正常来说on_instruction已经补充指令了,已经不会有这种情况了,缺点 print("执行了一次强制迁移指令到子节点!") return cur_node.children[0] else:#正常应该都有指令 return cur_node else: #添加了多个节点就难对应是哪个节点了,两个解决方案:1.返回当前节点,2.提交llm重新确认指令节点 return cur_node else: if node_names[-2] == cur_node.name: #父节点是当前节点 for child_node in cur_node.children: if child_node.name == node_name: return child_node # #走到这说明没有匹配到-则新建一个节点- # newNode = TreeNode(node_name,cur_node.task_id) # cur_node.add_child(newNode,cur_node.messages) return None else: return None #约束:不处理 def find_node_by_nodepath(self,node_path): '''基于节点路径查找节点,只返回找到的第一个节点,若有节点名称路径重复的情况暂不处理''' current_node = self.root #从根节点开始 node_names = node_path.split('->') layer_num = 0 for node_name in node_names: if node_name == "目标系统": layer_num +=1 continue if node_name == current_node.name:#根节点开始 layer_num += 1 continue else: bfound = False for child_node in current_node.children: if child_node.name == node_name: #约束同一父节点下的子节点名称不能相同 current_node = child_node layer_num += 1 bfound = True break if not bfound: #如果遍历子节点都没有符合的,说明路径有问题的,不处理中间一段路径情况 return None return current_node #更新节点的bwork状态 def update_node_bwork(self,node_path): node = self.find_node_by_nodepath(node_path) if not node: return False,False if node.bwork: node.bwork = False else: node.bwork = True return True,node.bwork def find_nodes_by_status(self, status): """根据状态查找所有匹配节点""" return [node for node in self.traverse_bfs() if node.status == status] def find_nodes_by_vul_type(self, vul_type): """根据漏洞类型查找所有匹配节点""" return [node for node in self.traverse_bfs() if node.vul_type == vul_type] #考虑要不要用tree封装节点的操作--待定 def update_node_status(self, node_name, new_status): """修改节点状态""" node = self.find_node_by_name(node_name) if node: node.status = new_status return True return False def update_node_vul_type(self,node_name,vul_type): """修改节点漏洞类型""" node = self.find_node_by_name(node_name) if node: node.vul_type = vul_type return True return False def print_tree(self, node=None, level=0): """可视化打印树结构""" if node is None: node = self.root prefix = " " * level + "|-- " if level > 0 else "" print(f"{prefix}{node.name} [{node.status}, {node.vul_type}]") for child in node.children: self.print_tree(child, level + 1) class TreeNode: def __init__(self, name,task_id,status="未完成", vul_type="未发现"): self.task_id = task_id #任务id self.name = name # 节点名称 #self.node_lock = threading.Lock() #线程锁 self.bwork = True # 当前节点是否工作,默认True --停止/启动 self.status = status # 节点测试状态 -- 由llm返回指令触发更新 #work_status需要跟两个list统一管理:初始0,入instr_queue为1,入instr_node_mq为2,入res_queue为3,入llm_node_mq为4,llm处理完0或1 self._work_status = 0 #0-无任务,1-待执行测试指令,2-执行指令中,3-待提交Llm,4-提交llm中, 2025-4-6新增,用来动态显示节点的工作细节。 #self.work_status_lock = threading.Lock() ---节点不能有锁 self.vul_type = vul_type # 漏洞类型--目前赋值时没拆json self.vul_name = "" self.vul_grade = "" self.vul_info = "" self.children = [] # 子节点列表 self.parent = None # 父节点引用 self.path = "" #当前节点的路径 self.parent_messages = [] #2024-4-23调整messages保存和传递策略,分两块,parent_messages是保留以前节点的messages self.cur_messages = [] # 针对当前节点积累的messages -- 针对不同节点提交不同的messages messages保留当前节点执行的messages self.llm_type = 0 #llm提交类型 0--初始状态无任务状态,1--指令结果反馈,2--llm错误反馈 self.llm_sn = 0 #针对该节点llm提交次数 self._llm_quere = [] #待提交llm的数据 self.do_sn = 0 #针对该节点instr执行次数 self._instr_queue = [] #针对当前节点的待执行指令----重要约束:一个节点只能有一个线程在执行指令 self.his_instr = [] #保留执行指令的记录{“instr”:***,"result":***} #用户补充信息 self.cookie = "" self.ext_info = "" #设置用户信息 def set_user_info(self,cookie,ext_info): self.cookie = cookie self.ext_info = ext_info def copy_messages(self,p_msg,c_msg): #2025-4-23修改用做给本节点加msg ''' 当前节点添加mesg,约束:p_msg除system只取两层,c_msg:只取最后两个 :param p_msg: 传递指令过来的节点的lis-msg --取传递过来的节点和parent :param c_msg: 传递指令过来的节点的当前节点-msg :return: ''' if not p_msg or not c_msg: print("Messages存储存在问题!需要立即检查逻辑!") return tmp_pmsg = copy.deepcopy(p_msg) tmp_cmsg = copy.deepcopy(c_msg) if not self.parent.parent: #正常来说self.parent肯定会有的,root节点不会触发copy_messages #如果路径不超过两级,pmsg都保留 self.parent_messages = tmp_pmsg else: #如果超过两级了,就需要判断了 pp_node_name = self.parent.parent.name bfind = False for msg in tmp_pmsg: if msg["role"] == "system": self.parent_messages.append(msg) else: if not bfind: if msg["role"] == "user": content = msg["content"] if pp_node_name in content: #节点名称在内容中就代表跟该节点相关 bfind = True self.parent_messages.append(msg) else:#当二级父节点出行后,后面数据应该都是二级内的数据 self.parent_messages.append(msg) #cmsg --取最后两轮 if len(tmp_cmsg) <=4: #cmsg全收 self.parent_messages.extend(tmp_cmsg) else: isart = len(tmp_cmsg) - 4 #正常应该都是两个两个 if isart % 2 != 0: print("c_msg数量不对称,需要检查逻辑!") for msg in tmp_cmsg[isart:]: self.parent_messages.append(msg) #添加子节点 def add_child(self, child_node): child_node.parent = self child_node.path = self.path + f"->{child_node.name}" #子节点的路径赋值 self.children.append(child_node) #修改节点的执行状态--return bchange def update_work_status(self,work_status): bsuccess = False if self._work_status != work_status: self._work_status = work_status bsuccess = True return bsuccess def get_work_status(self): #加锁有没有意义---web端和本身的工作线程会有同步问题,但与持久化相比,暂时忽略 work_status = self._work_status return work_status #-------后期扩充逻辑,目前wokr_status的修改交给上层类对象------- def add_instr(self,instr,p_msg,c_msg): if not self.parent_messages: #为空时赋值 self.copy_messages(p_msg,c_msg) self._instr_queue.append(instr) def test_add_instr(self,instr): self._instr_queue.append(instr) self._llm_quere = [] def get_instr(self): return self._instr_queue.pop(0) if self._instr_queue else None def get_instr_user(self): return self._instr_queue def del_instr(self,instr): if instr in self._instr_queue: self._instr_queue.remove(instr) #指令删除后要判断是否清空指令了 if not self._instr_queue: self._work_status = 0 #状态调整为没有带执行指令 return True,"" else: return False,"该指令不在队列中!" def add_res(self,str_res): #结构化结果字串 self._llm_quere.append(str_res) def get_res(self): return self._llm_quere.pop(0) if self._llm_quere else None def get_res_user(self): return self._llm_quere def get_work_status(self): return self._work_status def updatemsg(self,newtype,newcontent,index): if self._llm_quere:# oldmsg_llm_type = self._llm_quere[0]["llm_type"] #llm_type不修改,还未验证 newmsg = {"llm_type": int(oldmsg_llm_type), "result": newcontent} self._llm_quere[0] = newmsg else:#新增消息 newmsg = {"llm_type": int(newtype), "result": newcontent} self._llm_quere.append(newmsg) #更新节点状态 self._work_status = 3 #待提交 return True,"" def is_instr_empty(self): if self._instr_queue: return False return True def is_llm_empty(self): if self._llm_quere: return False return True def __repr__(self): return f"TreeNode({self.name}, {self.status}, {self.vul_type})" if __name__ == "__main__": pass