import queue #渗透测试树结构维护类 class AttackTree: def __init__(self,root_node): #针对根节点处理 self.root = root_node self.root.path = f"目标系统->{root_node.name}" def set_root(self,root_node): self.root = root_node def add_node(self,parent_name,new_node): """根据父节点名称添加新节点""" parent_node = self.find_node_by_name(parent_name) if parent_node: parent_node.add_child(new_node) return True return False def traverse_bfs(self): """广度优先遍历""" if not self.root: return [] queue = [self.root] result = [] while queue: current = queue.pop(0) result.append(current) queue.extend(current.children) return result def traverse_dfs(self, node=None, result=None): """深度优先遍历(前序遍历)""" if result is None: result = [] if node is None: node = self.root if not node: return [] result.append(node) for child in node.children: self.traverse_dfs(child, result) return result def find_node_by_name(self, name): """根据名称查找节点(广度优先)""" nodes = self.traverse_bfs() for node in nodes: if node.name == name: return node return None def find_node_by_nodepath(self,node_path): '''基于节点路径查找节点,只返回找到的第一个节点,若有节点名称路径重复的情况暂不处理''' current_node = self.root #从根节点开始 node_names = node_path.split('->') for node_name in node_names: if node_name == current_node.name:#根节点开始 continue else: bfound = False for child_node in current_node.children: if child_node.name == node_name: #约束同一父节点下的子节点名称不能相同 current_node = child_node bfound = True break if not bfound: #如果遍历子节点都没有符合的,说明路径有问题的,不处理中间一段路径情况 return None #找到的话,就开始匹配下一层 return current_node def find_nodes_by_status(self, status): """根据状态查找所有匹配节点""" return [node for node in self.traverse_bfs() if node.status == status] def find_nodes_by_vul_type(self, vul_type): """根据漏洞类型查找所有匹配节点""" return [node for node in self.traverse_bfs() if node.vul_type == vul_type] #考虑要不要用tree封装节点的操作--待定 def update_node_status(self, node_name, new_status): """修改节点状态""" node = self.find_node_by_name(node_name) if node: node.status = new_status return True return False def update_node_vul_type(self,node_name,vul_type): """修改节点漏洞类型""" node = self.find_node_by_name(node_name) if node: node.vul_type = vul_type return True return False def print_tree(self, node=None, level=0): """可视化打印树结构""" if node is None: node = self.root prefix = " " * level + "|-- " if level > 0 else "" print(f"{prefix}{node.name} [{node.status}, {node.vul_type}]") for child in node.children: self.print_tree(child, level + 1) class TreeNode: def __init__(self, name, status="未完成", vul_type="未发现"): self.name = name # 节点名称 self.status = status # 节点状态 self.vul_type = vul_type # 漏洞类型 self.children = [] # 子节点列表 self.parent = None # 父节点引用 self.path = "" #当前节点的路径 self.instr_queue = queue.Queue() #针对当前节点的执行指令----重要约束:一个节点只能有一个线程在执行指令 self.res_quere = queue.Queue() #指令执行的结果,一批一批 self.llm_sn = 0 #针对该节点llm提交次数 self.do_sn = 0 #针对该节点instr执行次数 self.messages = [] #针对当前节点积累的messages -- 针对不同节点提交不同的messages def add_child(self, child_node): child_node.parent = self child_node.path = self.path + f"->{child_node.name}" #子节点的路径赋值 child_node.messages = self.messages #传递messages #给什么时候的messages待验证#? self.children.append(child_node) def add_instr(self,instr): self.instr_queue.put(instr) def get_instr(self): if self.instr_queue.empty(): return None return self.instr_queue.get() def add_res(self,str_res): #结构化结果字串 self.res_quere.put(str_res) def get_res(self): if self.res_quere.empty(): return None return self.res_quere.get() def __repr__(self): return f"TreeNode({self.name}, {self.status}, {self.vul_type})" if __name__ == "__main__": pass