You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

171 lines
6.2 KiB

import queue
#渗透测试树结构维护类
class AttackTree:
def __init__(self,root_node):
#针对根节点处理
self.root = root_node
self.root.path = f"目标系统->{root_node.name}"
def set_root(self,root_node):
self.root = root_node
def add_node(self,parent_name,new_node):
"""根据父节点名称添加新节点"""
parent_node = self.find_node_by_name(parent_name)
if parent_node:
parent_node.add_child(new_node)
return True
return False
def traverse_bfs(self):
"""广度优先遍历"""
if not self.root:
return []
queue = [self.root]
result = []
while queue:
current = queue.pop(0)
result.append(current)
queue.extend(current.children)
return result
def traverse_dfs(self, node=None, result=None):
"""深度优先遍历(前序遍历)"""
if result is None:
result = []
if node is None:
node = self.root
if not node:
return []
result.append(node)
for child in node.children:
self.traverse_dfs(child, result)
return result
def find_node_by_name(self, name):
"""根据名称查找节点(广度优先)"""
nodes = self.traverse_bfs()
for node in nodes:
if node.name == name:
return node
return None
def find_node_by_nodepath_parent(self,node_path,node):
node_names = node_path.split('->')
node_name = node_names[-1]
if node_name == node.name:#当前节点
return node
else:
if node_names[-2] == node.name: #父节点是当前节点
for child_node in node.children:
if child_node.name == node_name:
return child_node
#走到这说明没有匹配到-则新建一个节点
newNode = TreeNode(node_name)
node.add_child(newNode)
return newNode
else:
return None #约束:不处理
def find_node_by_nodepath(self,node_path):
'''基于节点路径查找节点,只返回找到的第一个节点,若有节点名称路径重复的情况暂不处理'''
current_node = self.root #从根节点开始
node_names = node_path.split('->')
layer_num = 0
for node_name in node_names:
if node_name == "目标系统":
layer_num +=1
continue
if node_name == current_node.name:#根节点开始
layer_num += 1
continue
else:
bfound = False
for child_node in current_node.children:
if child_node.name == node_name: #约束同一父节点下的子节点名称不能相同
current_node = child_node
layer_num += 1
bfound = True
break
if not bfound: #如果遍历子节点都没有符合的,说明路径有问题的,不处理中间一段路径情况
return None
return current_node
def find_nodes_by_status(self, status):
"""根据状态查找所有匹配节点"""
return [node for node in self.traverse_bfs() if node.status == status]
def find_nodes_by_vul_type(self, vul_type):
"""根据漏洞类型查找所有匹配节点"""
return [node for node in self.traverse_bfs() if node.vul_type == vul_type]
#考虑要不要用tree封装节点的操作--待定
def update_node_status(self, node_name, new_status):
"""修改节点状态"""
node = self.find_node_by_name(node_name)
if node:
node.status = new_status
return True
return False
def update_node_vul_type(self,node_name,vul_type):
"""修改节点漏洞类型"""
node = self.find_node_by_name(node_name)
if node:
node.vul_type = vul_type
return True
return False
def print_tree(self, node=None, level=0):
"""可视化打印树结构"""
if node is None:
node = self.root
prefix = " " * level + "|-- " if level > 0 else ""
print(f"{prefix}{node.name} [{node.status}, {node.vul_type}]")
for child in node.children:
self.print_tree(child, level + 1)
class TreeNode:
def __init__(self, name, status="未完成", vul_type="未发现"):
self.name = name # 节点名称
self.status = status # 节点状态
self.vul_type = vul_type # 漏洞类型
self.children = [] # 子节点列表
self.parent = None # 父节点引用
self.path = "" #当前节点的路径
self.messages = [] # 针对当前节点积累的messages -- 针对不同节点提交不同的messages
self.llm_type = 0 #llm提交类型 0--初始状态无任务状态,1--指令结果反馈,2--llm错误反馈
self.llm_sn = 0 #针对该节点llm提交次数
self.do_sn = 0 #针对该节点instr执行次数
self.instr_queue = [] # queue.Queue() #针对当前节点的执行指令----重要约束:一个节点只能有一个线程在执行指令
self.res_quere = [] # queue.Queue() #指令执行的结果,一批一批
def add_child(self, child_node):
child_node.parent = self
child_node.path = self.path + f"->{child_node.name}" #子节点的路径赋值
child_node.messages = self.messages #传递messages #给什么时候的messages待验证#?
self.children.append(child_node)
def add_instr(self,instr):
self.instr_queue.append(instr)
def get_instr(self):
return self.instr_queue.pop(0) if self.instr_queue else None
def add_res(self,str_res): #结构化结果字串
self.res_quere.append(str_res)
def get_res(self):
return self.res_queue.pop(0) if self.res_queue else None
def __repr__(self):
return f"TreeNode({self.name}, {self.status}, {self.vul_type})"
if __name__ == "__main__":
pass