import json
import os
import sqlite3
import threading
from contextlib import closing
from datetime import timedelta

import pymysql

from myutils.ConfigManager import myCongif
from myutils.MyLogger_logger import LogHandler
from myutils.MyTime import get_local_timestr

class DBManager:
    """Wrapper around a single MySQL (pymysql) or SQLite (sqlite3) connection.

    The backend is selected by the ``DBType`` config value:
    0 - MySQL via pymysql, 1 - SQLite via sqlite3.
    Every query method serialises access to the shared connection through
    ``self.lock``.

    NOTE(review): the task-specific helpers below use ``%s`` placeholders
    (pymysql's paramstyle); sqlite3 expects ``?``, so those helpers currently
    only work against MySQL.
    """

    def __init__(self):
        """Read connection settings from config; call connect() to open the connection.

        :raises FileNotFoundError: SQLite backend and the configured database
            file exists neither relative to the cwd nor one directory up.
        """
        self.logger = LogHandler().get_logger("DBManager")
        self.lock = threading.Lock()
        self.itype = myCongif.get_data("DBType")
        self.ok = False
        # Set by connect(); initialised here so every backend has the attribute.
        self.connection = None
        if self.itype == 0:
            self.host = myCongif.get_data('mysql.host')
            self.port = myCongif.get_data('mysql.port')
            self.user = myCongif.get_data('mysql.user')
            self.passwd = myCongif.get_data('mysql.passwd')
            self.database = myCongif.get_data('mysql.database')
        elif self.itype == 1:
            self.dbfile = myCongif.get_data("sqlite")
            if not os.path.exists(self.dbfile):
                # When DBManager is run directly the cwd is not the project root.
                self.dbfile = "../" + self.dbfile
                if not os.path.exists(self.dbfile):
                    raise FileNotFoundError(f"Database file {self.dbfile} does not exist.")
        else:
            self.logger.error("错误的数据库类型,请检查")

    def __del__(self):
        """Best-effort close of the connection when the manager is garbage-collected."""
        # getattr guards: __init__ may have failed before these attributes existed.
        conn = getattr(self, "connection", None)
        if getattr(self, "ok", False) and conn is not None:
            try:
                conn.close()
            except Exception:
                # e.g. pymysql raises on an already-closed connection; nothing
                # useful can be done about it during teardown.
                pass
            self.connection = None
        logger = getattr(self, "logger", None)
        if logger is not None:
            logger.debug("DBManager销毁")

    def connect(self):
        """Open the connection for the configured backend.

        :return: True on success; False if the backend type is unknown or the
                 connection attempt raised.
        """
        try:
            if self.itype == 0:
                # NOTE(review): MySQL 'utf8' is 3-byte utf8mb3 and cannot store
                # emoji; consider 'utf8mb4' if the tables support it.
                self.connection = pymysql.connect(host=self.host, port=self.port, user=self.user,
                                                  passwd=self.passwd, db=self.database, charset='utf8')
            elif self.itype == 1:
                # NOTE(review): sqlite3 connections default to check_same_thread=True,
                # so using this connection from other threads will raise.
                self.connection = sqlite3.connect(self.dbfile)
            else:
                # Unknown backend: don't report success without a connection object.
                self.logger.error("错误的数据库类型,请检查")
                return False
            self.ok = True
            self.logger.debug("服务器端数据库连接成功")
            return True
        except Exception as e:
            self.logger.error("服务器端数据库连接失败:%s" % str(e))
            return False

    def Retest_conn(self):
        """Check the connection is usable, reconnecting on failure.

        Only MySQL connections are pinged; other backends are assumed usable.
        :return: True if the connection is (now) usable.
        """
        if self.itype != 0:
            return True
        try:
            self.connection.ping()
        except Exception:
            return self.connect()
        return True

    def _cursor(self):
        """Return a cursor usable in a ``with`` statement on both backends.

        sqlite3 cursors do not implement the context-manager protocol, so
        ``with connection.cursor()`` fails there; ``contextlib.closing`` closes
        the cursor on exit for pymysql and sqlite3 alike.
        """
        return closing(self.connection.cursor())

    def _rollback(self):
        """Roll back the current transaction, logging instead of raising on failure.

        Called from error handlers, where the connection may already be broken.
        """
        try:
            self.connection.rollback()
        except Exception as e:
            self.logger.error("数据库回滚失败:%s" % str(e))

    def do_select(self, strsql, itype=0):
        """Run a (non-parameterised) query.

        :param strsql: SQL text to execute.
        :param itype: 1 - fetch only the first row; anything else - fetch all rows.
        :return: the fetched row(s), or None on connection failure or error.
        """
        with self.lock:
            if not self.Retest_conn():
                return None
            try:
                # Commit first to end any open transaction; otherwise InnoDB's
                # transaction snapshot may hide rows committed by others.
                self.connection.commit()
                with self._cursor() as cursor:
                    cursor.execute(strsql)
                    if itype == 1:
                        return cursor.fetchone()
                    return cursor.fetchall()
            except Exception as e:
                self.logger.error("do_select异常报错:%s" % str(e))
                return None

    def do_sql(self, strsql, data=None):
        """Execute a statement (or a batch) and commit.

        :param strsql: SQL text to execute.
        :param data: optional sequence of parameter sets; when truthy the
                     statement is run once per set via executemany().
        :return: True if the statement executed and was committed.
        """
        bok = False
        with self.lock:
            if self.Retest_conn():
                try:
                    with self._cursor() as cursor:
                        if data:
                            cursor.executemany(strsql, data)  # batch execution
                        else:
                            cursor.execute(strsql)
                        self.connection.commit()
                        bok = True
                except Exception as e:
                    self.logger.error("执行数据库语句%s出错:%s" % (strsql, str(e)))
                    self._rollback()
        return bok

    def safe_do_sql(self, strsql, params, itype=0):
        """Execute a parameterised statement and commit.

        :param strsql: SQL text with driver placeholders.
        :param params: parameter sequence bound to the placeholders.
        :param itype: 1 - also return the cursor's lastrowid.
        :return: (ok, lastrowid) - lastrowid is 0 unless itype == 1 and the
                 statement succeeded.
        """
        bok = False
        task_id = 0
        with self.lock:
            if self.Retest_conn():
                try:
                    with self._cursor() as cursor:
                        cursor.execute(strsql, params)
                        self.connection.commit()
                        if itype == 1:
                            task_id = cursor.lastrowid
                        bok = True
                except Exception as e:
                    self.logger.error("执行数据库语句%s出错:%s" % (strsql, str(e)))
                    self._rollback()
        return bok, task_id

    def safe_do_select(self, strsql, params, itype=0):
        """Run a parameterised query.

        :param strsql: SQL text with driver placeholders.
        :param params: parameter sequence bound to the placeholders.
        :param itype: 0 - fetch all rows; 1 - fetch the first row.
        :return: the fetched row(s); [] on connection failure or error.
        """
        results = []
        with self.lock:
            if self.Retest_conn():
                try:
                    # Inside the try so a failing commit cannot escape the handler.
                    self.connection.commit()
                    with self._cursor() as cursor:
                        cursor.execute(strsql, params)  # parameterised query
                        if itype == 0:
                            results = cursor.fetchall()
                        elif itype == 1:
                            results = cursor.fetchone()
                except Exception as e:
                    self.logger.error(f"查询出错: {e}")
        return results

    def is_json(self, s: str) -> bool:
        """Return True if *s* is a str containing valid JSON."""
        if not isinstance(s, str):
            return False
        try:
            json.loads(s)
            return True
        except Exception:
            # JSONDecodeError for malformed text; anything else (e.g. deep
            # nesting) is also treated as "not JSON".
            return False

    @staticmethod
    def timedelta_to_str(delta: timedelta) -> str:
        """Format *delta* as ``HH:MM:SS`` (hours are not wrapped at 24)."""
        hours, remainder = divmod(delta.total_seconds(), 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"

    @staticmethod
    def _to_json_text(value, ensure_ascii=False):
        """Normalise *value* to a JSON text string.

        Strings that already parse as JSON are returned unchanged; any other
        value is serialised with json.dumps; values json.dumps rejects are
        serialised as their str() form.
        """
        try:
            if not isinstance(value, str):
                return json.dumps(value, ensure_ascii=ensure_ascii)
            json.loads(value)  # validate only
            return value
        except (TypeError, ValueError):  # ValueError covers JSONDecodeError
            return json.dumps(str(value), ensure_ascii=ensure_ascii)

    @staticmethod
    def _unescape_text(text):
        """Decode backslash escape sequences in *text* into literal characters.

        Encoding via latin-1 with 'backslashreplace' turns characters above
        U+00FF into \\uXXXX / \\UXXXXXXXX escapes, which 'unicode_escape' then
        decodes back to the same characters. Encoding as UTF-8 instead would
        make 'unicode_escape' read the UTF-8 bytes as Latin-1 (mojibake).
        """
        return text.encode('latin-1', 'backslashreplace').decode('unicode_escape')

    # --------------------- Task-specific database operations ---------------------
    def get_system_info(self):
        """Return the first row of zf_system (None on error)."""
        strsql = "select * from zf_system;"
        data = self.do_select(strsql, 1)
        return data

    def get_run_tasks(self):
        """Return all tasks whose task_status is not 2."""
        strsql = "select ID,task_target,task_status,work_type,cookie_info,llm_type from task where task_status <> 2;"
        datas = self.do_select(strsql)
        return datas

    def start_task(self, test_target, cookie_info, work_type, llm_type) -> int:
        '''
        Insert a new detection task row (task_status=1, safe_rank=0).
        :param test_target: target under test; stored as both task_name and task_target
        :param cookie_info: stored in the cookie_info column
        :param work_type: stored in the work_type column
        :param llm_type: stored in the llm_type column
        :return: the new row's ID, or 0 if the insert failed
        '''
        start_time = get_local_timestr()
        sql = "INSERT INTO task (task_name,task_target,start_time,task_status,safe_rank,work_type,cookie_info,llm_type) " \
              "VALUES (%s,%s,%s,%s,%s,%s,%s,%s)"
        params = (test_target, test_target, start_time, 1, 0, work_type, cookie_info, llm_type)
        _, task_id = self.safe_do_sql(sql, params, 1)
        return task_id

    def over_task(self, task_id):
        """Mark task *task_id* finished (task_status=2); return True on success."""
        strsql = "update task set task_status=2 where ID=%s;"
        params = (task_id,)  # a real 1-tuple, as DB-API parameter binding expects
        bok, _ = self.safe_do_sql(strsql, params)
        return bok

    def insetr_result(self, task_id, instruction, result, do_sn, start_time, end_time, source_result, ext_params, node_path):
        """Store one instruction execution result in task_result.

        *result* and *source_result* are both normalised to JSON text.
        *ext_params* must provide 'is_user' and 'is_vulnerability' keys.
        :return: True if the insert succeeded.
        """
        str_result = self._to_json_text(result)
        str_source_result = self._to_json_text(source_result)

        # Parameterised query
        sql = """
            INSERT INTO task_result 
                (task_id, instruction, result, do_sn,start_time,end_time,source_result,is_user,is_vulnerability,node_path)
            VALUES 
                (%s, %s, %s, %s, %s, %s,%s,%s,%s,%s)
        """
        # Bind the normalised source_result; the raw value (e.g. a dict) cannot
        # be bound as a single column value.
        params = (task_id, instruction, str_result, do_sn, start_time, end_time, str_source_result,
                  ext_params['is_user'], ext_params['is_vulnerability'], node_path)
        bok, _ = self.safe_do_sql(sql, params)
        return bok

    def insert_llm(self, task_id, prompt, reasoning_content, content, post_time, llm_sn, path):
        """Store one LLM exchange in task_llm.

        reasoning_content and content are normalised to JSON text, then their
        escape sequences are decoded to literal characters before storage.
        :return: True if the insert succeeded.
        """
        str_reasoning = self._unescape_text(self._to_json_text(reasoning_content))
        str_content = self._unescape_text(self._to_json_text(content))

        sql = """
            INSERT INTO task_llm 
                (task_id,do_sn,prompt,reasoning_content,content,start_time,node_path)
            VALUES 
                (%s, %s, %s, %s, %s, %s,%s)
        """
        params = (task_id, llm_sn, prompt, str_reasoning, str_content, post_time, path)
        bok, _ = self.safe_do_sql(sql, params)
        return bok

    def get_task_instrs(self, task_id, nodename):
        """Placeholder: intended to return a task's executed test instructions; always []."""
        instrs = []
        return instrs

    def get_task_vul(self, task_id, nodename, vultype, vullevel):
        """Placeholder: intended to return a task's vulnerability findings; always []."""
        vuls = []
        return vuls

    def get_task_node_done_instr(self, task_id, nodepath):
        """Return (instruction, start_time, result) rows already executed for this task and node path."""
        strsql = '''
        select instruction,start_time,result from task_result where task_id=%s and node_path=%s;
        '''
        params = (task_id, nodepath)
        datas = self.safe_do_select(strsql, params)
        return datas

    def test(self):
        """Standalone pymysql usage example with placeholder credentials.

        Does not use this manager's connection; prints every row of table_name.
        """
        conn = pymysql.connect(
            host='localhost',
            port=3306,
            user='username',
            password='password',
            database='database_name'
        )
        try:
            with closing(conn.cursor()) as cursor:
                cursor.execute("SELECT * FROM table_name")
                for row in cursor.fetchall():
                    print(row)
        finally:
            # Close the connection even if the query fails.
            conn.close()

# Global singleton shared by modules that import this file; it connects at import time.
app_DBM = DBManager()
app_DBM.connect()

if __name__ == "__main__":
    mDBM = DBManager()
    mDBM.connect()
    # start_task() takes four arguments (test_target, cookie_info, work_type, llm_type).
    # NOTE(review): 0/0 are placeholder work_type/llm_type values for this manual
    # smoke test - confirm valid codes.
    print(mDBM.start_task("11", "22", 0, 0))