import queue
import copy
import re
import threading

#渗透测试树结构维护类
class AttackTree:
    def __init__(self,root_node):
        #针对根节点处理
        self.root = root_node
        self.root.path = f"目标系统->{root_node.name}"


    def set_root(self,root_node):
        self.root = root_node

    def traverse_bfs(self):
        """广度优先遍历"""
        if not self.root:
            return []

        queue = [self.root]
        result = []
        while queue:
            current = queue.pop(0)
            result.append(current)
            queue.extend(current.children)
        return result

    def traverse_dfs(self, node=None, result=None):
        """深度优先遍历(前序遍历)"""
        if result is None:
            result = []
        if node is None:
            node = self.root
            if not node:
                return []

        result.append(node)
        for child in node.children:
            self.traverse_dfs(child, result)
        return result

    #生成节点树字典数据
    def node_to_dict(self,node):
        return {
            "node_name":node.name,
            "node_path":node.path,
            "node_status":node.status,
            "node_bwork":node.bwork,
            "node_vultype":node.vul_type,
            "node_vulgrade":node.vul_grade,
            "node_vulinfo":node.vul_info,
            "node_workstatus":node.get_work_status(),
            "children":[self.node_to_dict(child) for child in node.children] if node.children else []
        }

    #树简化列表,用户传输到前端
    def get_node_dict(self):
        node_dict = self.node_to_dict(self.root)    #递归生成
        return node_dict

    def find_node_by_name(self, name):
        """根据名称查找节点(广度优先)"""
        nodes = self.traverse_bfs()
        for node in nodes:
            if node.name == name:
                return node
        return None

    #返回需要插入测试指令的节点
    def find_node_by_nodepath_parent(self,node_path,cur_node,iadd_node,commands):
        node_names = node_path.split('->')
        node_name = node_names[-1]  #返回指令对应的节点名称
        if node_name == cur_node.name:#当前节点
            if iadd_node == 1 and len(cur_node.children)==1:  #如果在当前节点下添加了一个子节点,且子节点没有添加指令,就算指令的节点路径是当前节点,也把该测试指令添加给新增的子节点
                bfind = False
                for comd in commands:
                    if cur_node.children[0].name in comd:
                        bfind = True
                        break
                if not bfind:   #正常来说on_instruction已经补充指令了,已经不会有这种情况了,缺点
                    print("执行了一次强制迁移指令到子节点!")
                    return cur_node.children[0]
                else:#正常应该都有指令
                    return cur_node
            else:
                #添加了多个节点就难对应是哪个节点了,两个解决方案:1.返回当前节点,2.提交llm重新确认指令节点
                return cur_node
        else:
            if node_names[-2] == cur_node.name: #父节点是当前节点
                for child_node in cur_node.children:
                    if child_node.name == node_name:
                        return child_node
                # #走到这说明没有匹配到-则新建一个节点- 少个layer
                # newNode = TreeNode(node_name,cur_node.task_id)
                # cur_node.add_child(newNode,cur_node.messages)
                return None
            else:
                return None #约束:不处理

    def find_node_by_nodepath(self,node_path):
        '''基于节点路径查找节点,只返回找到的第一个节点,若有节点名称路径重复的情况暂不处理'''
        current_node = self.root #从根节点开始
        node_names = node_path.split('->')
        layer_num = 0
        for node_name in node_names:
            if node_name == "目标系统":
                layer_num +=1
                continue
            if node_name == current_node.name:#根节点开始
                layer_num += 1
                continue
            else:
                bfound = False
                for child_node in current_node.children:
                    if child_node.name == node_name:    #约束同一父节点下的子节点名称不能相同
                        current_node = child_node
                        layer_num += 1
                        bfound = True
                        break
                if not bfound:  #如果遍历子节点都没有符合的,说明路径有问题的,不处理中间一段路径情况
                    return  None
        return current_node

    #更新节点的bwork状态
    def update_node_bwork(self,node_path):
        node = self.find_node_by_nodepath(node_path)
        if not node:
            return False,False
        if node.bwork:
            node.bwork = False
        else:
            node.bwork = True
        return True,node.bwork

    def find_nodes_by_status(self, status):
        """根据状态查找所有匹配节点"""
        return [node for node in self.traverse_bfs() if node.status == status]

    def find_nodes_by_vul_type(self, vul_type):
        """根据漏洞类型查找所有匹配节点"""
        return [node for node in self.traverse_bfs() if node.vul_type == vul_type]

    #考虑要不要用tree封装节点的操作--待定
    def update_node_status(self, node_name, new_status):
        """修改节点状态"""
        node = self.find_node_by_name(node_name)
        if node:
            node.status = new_status
            return True
        return False

    def update_node_vul_type(self,node_name,vul_type):
        """修改节点漏洞类型"""
        node = self.find_node_by_name(node_name)
        if node:
            node.vul_type = vul_type
            return True
        return False

    def print_tree(self, node=None, level=0):
        """可视化打印树结构"""
        if node is None:
            node = self.root
        prefix = "    " * level + "|-- " if level > 0 else ""
        print(f"{prefix}{node.name} [{node.status}, {node.vul_type}]")
        for child in node.children:
            self.print_tree(child, level + 1)


class TreeNode:
    def __init__(self, name,task_id,node_layer,status="未完成", vul_type="未发现"):
        self.task_id = task_id  #任务id
        self.name = name  # 节点名称
        self.cur_layer = node_layer  # 节点当前层数
        self.bwork = True  # 当前节点是否工作,默认True --停止/启动
        self.status = status  # 节点测试状态 -- 由llm返回指令触发更新
        #work_status需要跟两个list统一管理:初始0,入instr_queue为1,入instr_node_mq为2,入res_queue为3,入llm_node_mq为4,llm处理完0或1
        self._work_status = 0    #0-无任务,1-待执行测试指令,2-执行指令中,3-待提交Llm,4-提交llm中, 2025-4-6新增,用来动态显示节点的工作细节。
        #self.work_status_lock = threading.Lock() ---节点不能有锁
        self.vul_type = vul_type  # 漏洞类型--目前赋值时没拆json
        self.vul_name = ""
        self.vul_grade = ""
        self.vul_info = ""
        self.children = []  # 子节点列表
        self.parent = None  # 父节点引用
        self.path = ""      #当前节点的路径
        self.parent_messages = []   #2024-4-23调整messages保存和传递策略,分两块,parent_messages是保留以前节点的messages
        self.cur_messages = []  # 针对当前节点积累的messages -- 针对不同节点提交不同的messages  messages保留当前节点执行的messages

        self.llm_type = 0   #llm提交类型 0--初始状态无任务状态,1--指令结果反馈,2--llm错误反馈
        self.llm_sn = 0     #针对该节点llm提交次数
        self._llm_quere = [] #待提交llm的数据

        self.do_sn = 0      #针对该节点instr执行次数
        self._instr_queue = []  #针对当前节点的待执行指令----重要约束:一个节点只能有一个线程在执行指令

        self.his_instr = []   #保留执行指令的记录{“instr”:***,"result":***}
        #单步相关
        self.step_num = 0   #单步执行次数
        #用户补充信息
        self.cookie = ""
        self.ext_info = ""
        #线程锁-- 2025-5-9 两个list 合并在一起管理,只会有一个List 有值
        self.work_status_lock = threading.Lock()

    def __getstate__(self):
        state = self.__dict__.copy()
        for attr in ('work_status_lock',):
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        # 恢复其余字段
        self.__dict__.update(state)
        # 重建运行时用的锁
        self.work_status_lock = threading.Lock()

    #设置用户信息
    def set_user_info(self,cookie,ext_info):
        self.cookie = cookie
        self.ext_info = ext_info

    #添加子节点
    def add_child(self, child_node):
        child_node.parent = self
        child_node.path = self.path + f"->{child_node.name}"    #子节点的路径赋值
        child_node.step_num = self.step_num
        self.children.append(child_node)

    #---------------------messages处理--------------------
    def copy_messages(self,p_msg,c_msg): #2025-4-23修改用做给本节点加msg
        '''
        当前节点添加mesg,约束:p_msg除system只取两层,c_msg:只取最后两个
        :param p_msg: 传递指令过来的节点的lis-msg --取传递过来的节点和parent
        :param c_msg: 传递指令过来的节点的当前节点-msg
        :return:
        '''
        if not p_msg or not c_msg:
            print("Messages存储存在问题!需要检查逻辑!")
            return
        tmp_pmsg = copy.deepcopy(p_msg)
        tmp_cmsg = copy.deepcopy(c_msg)
        if not self.parent.parent:  #正常来说self.parent肯定会有的,root节点不会触发copy_messages
            #如果路径不超过两级,pmsg都保留
            self.parent_messages = tmp_pmsg
        else: #如果超过两级了,就需要判断了
            pp_node_name = self.parent.parent.name
            bfind = False
            for msg in tmp_pmsg:
                if msg["role"] == "system":
                    self.parent_messages.append(msg)
                else:
                    if not bfind:
                        if msg["role"] == "user":
                            content = msg["content"]
                            if pp_node_name in content: #节点名称在内容中就代表跟该节点相关
                                bfind = True
                                self.parent_messages.append(msg)
                    else:#当二级父节点出行后,后面数据应该都是二级内的数据
                        self.parent_messages.append(msg)
        #cmsg --取最后两轮
        if len(tmp_cmsg) <=4: #cmsg全收
            self.parent_messages.extend(tmp_cmsg)
        else:
            isart = len(tmp_cmsg) - 4  #正常应该都是两个两个
            if isart % 2 != 0:
                print("c_msg数量不对称,需要检查逻辑!")
            for msg in tmp_cmsg[isart:]:
                self.parent_messages.append(msg)

    def updatemsg(self,newtype,newcontent,p_msg,c_msg,index=0):   #index待处理,目前待提交状态时,只应该有一条待提交数据
        with self.work_status_lock:
            if self._work_status == 0:  #新增指令
                if not self._llm_quere:
                    #判断是否要copy-父节点msg
                    if not self.parent_messages:
                        self.copy_messages(p_msg,c_msg)
                    newmsg = {"llm_type": int(newtype), "result": newcontent}
                    self._llm_quere.append(newmsg)
                    # 更新节点状态
                    self._work_status = 3  # 待提交
                else:
                    return False,"新增指令,待提交数据不应该有数据"
            elif self._work_status == 3:   #只允许待提交状态修改msg
                if self._llm_quere:
                    oldmsg = self._llm_quere[0]
                    oldmsg_llm_type = oldmsg["llm_type"]  # llm_type不允许修改
                    newmsg = {"llm_type": int(oldmsg_llm_type), "result": newcontent}
                    self._llm_quere[0] = newmsg
                else:
                    return False,"状态是待提交,不应该没有待提交数据"
            else:
                return False,"该状态,不运行修改待提交数据"
        return True,""

    def is_instr_empty(self):#待改 --根据work_status判断
        with self.work_status_lock:
            if self._instr_queue:
                return False
            return True

    def is_llm_empty(self):#待改  --根据work_status判断
        with self.work_status_lock:
            if self._llm_quere:
                return False
            return True

    #修改节点的执行状态--return bchange 只能是2-4
    def update_work_status(self,work_status):
        bsuccess = True
        with self.work_status_lock:
            if self._work_status == 0: #初始状态
                self._work_status = work_status
            else:
                if self._work_status == 1 and work_status == 2: #只允许从1-》2
                    self._work_status  = 2
                elif self._work_status == 3 and work_status == 4:#只允许从3-》4
                    self._work_status = 4
                elif self._work_status ==4 and work_status == 0: #4->0
                    self._work_status = 0
                elif work_status == -1:
                    self._work_status = 0
                elif work_status == -2:
                    self._work_status = 2
                elif work_status == -3:
                    self._work_status = 4
                elif work_status == -4: #测试调用
                    self._work_status = 1
                else:
                    bsuccess = False
        return bsuccess

    def get_work_status(self):
        #加锁有没有意义---web端和本身的工作线程会有同步问题
        work_status = self._work_status
        return work_status

    def add_instr(self,instr,p_msg,c_msg):    #所有指令一次提交
        if instr:
            with self.work_status_lock:
                if not self.parent_messages:   #为空时赋值
                    self.copy_messages(p_msg,c_msg)
                if self._work_status in (0,1,4):
                    self._instr_queue.append(instr)
                    self._work_status = 1   #待执行
                    return True
                else:
                    print("插入指令时,状态不为-1,1,4!")
                    return False,"节点的工作状态不是0或4,请检查程序逻辑"
        else:
            return False,"指令数据为空"

    def test_add_instr(self, instr):
        self._instr_queue.append(instr)
        self._llm_quere = []

    def get_instr(self):
        with self.work_status_lock:
            if self._work_status == 2:   #执行中
                return self._instr_queue.pop(0) if self._instr_queue else None
            else:
                print("不是执行中,不应该来取指令!")
                return None

    def del_instr(self,instr):  #web端,手动删除指令
        with self.work_status_lock:
            if self._work_status == 1:
                if instr in self._instr_queue:
                    self._instr_queue.remove(instr)
                    #指令删除后要判断是否清空指令了
                    if not self._instr_queue:
                        self._work_status = 0  #状态调整为无待执行任务
                    return True,""
                else:
                    return False,"该指令不在队列中!"
            else:
                return  False,"只有待执行时,允许删除指令"

    def add_res(self,str_res,itype =0): #llm_queue入库的情况比较多,2,0,4
        if str_res:
            with self.work_status_lock:
                if self._work_status in (2,0,4):
                    if itype == 1:  #要插入到第一个
                        tmplist = []
                        tmplist.append(str_res)
                        tmplist.extend(self._llm_quere)
                        self._llm_quere = tmplist
                    else:
                        self._llm_quere.append(str_res)
                    if self._work_status in (2,0):  #提交中,不要改变执行状态
                        self._work_status =3
                else:
                    print("添加llm数据时,状态不是-1,0,2,4中的一种情况")
                    return False,"添加llm数据时,状态不是-1,0,2,4中的一种情况"
        else:
            return False,"待提交llm的数据为空"

    def get_res(self):
        with self.work_status_lock:
            if self._work_status ==4:   #提交中
                return self._llm_quere.pop(0) if self._llm_quere else None
            else:
                print("不是提交中,不应该来取待提交数据!")
                return None

    def clear_res(self):
        with self.work_status_lock:
            self._llm_quere.clear()

    #-----------web查看数据-----------
    def get_instr_user(self):   #读不用锁了 -有错误问题不大
        with self.work_status_lock:
            instr_que = self._instr_queue.copy()
        return instr_que

    def get_res_user(self): #读不用锁了 -- 有错误问题不大
        with self.work_status_lock:
            llm_que = self._llm_quere.copy()
        return llm_que

    def __repr__(self):
        return f"TreeNode({self.name}, {self.status}, {self.vul_type})"

if __name__ == "__main__":
    pass