import json
import os
import sqlite3
import threading
from contextlib import closing
from datetime import datetime, timedelta

import pymysql

from myutils.ConfigManager import myCongif
from myutils.MyLogger_logger import LogHandler
from myutils.MyTime import get_local_timestr


class DBManager:
    """Database manager: instantiate, then call connect().

    The backend is selected by the ``DBType`` config entry:
    0 - MySQL (via pymysql), 1 - SQLite (via sqlite3).
    Every query helper serializes access to the single shared connection
    through ``self.lock``.
    """

    def __init__(self):
        self.logger = LogHandler().get_logger("DBManager")
        self.lock = threading.Lock()
        self.itype = myCongif.get_data("DBType")
        self.ok = False
        # Defined for every backend so Retest_conn/__del__ never hit AttributeError.
        self.connection = None
        if self.itype == 0:
            self.host = myCongif.get_data('mysql.host')
            self.port = myCongif.get_data('mysql.port')
            self.user = myCongif.get_data('mysql.user')
            self.passwd = myCongif.get_data('mysql.passwd')
            self.database = myCongif.get_data('mysql.database')
        elif self.itype == 1:
            self.dbfile = myCongif.get_data("sqlite")
            if not os.path.exists(self.dbfile):
                # When DBManager is run directly the working directory is not the
                # project root, so retry one level up.
                self.dbfile = "../" + self.dbfile
                if not os.path.exists(self.dbfile):
                    raise FileNotFoundError(f"Database file {self.dbfile} does not exist.")
        else:
            self.logger.error("错误的数据库类型,请检查")

    def __del__(self):
        # getattr: __init__ may have raised before these attributes were set.
        if getattr(self, "ok", False):
            self._close_quietly()
            self.logger.debug("DBManager销毁")

    def _close_quietly(self):
        """Close and forget the current connection, ignoring errors from a broken link."""
        conn = getattr(self, "connection", None)
        self.connection = None
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass  # best effort: the link may already be closed or dead

    def connect(self):
        """Open (or re-open) the database connection.

        :return: True on success; False if connecting failed or DBType is unknown.
        """
        try:
            if self.itype == 0:
                new_conn = pymysql.connect(host=self.host, port=self.port, user=self.user,
                                           passwd=self.passwd, db=self.database, charset='utf8')
            elif self.itype == 1:
                # The lock exists so the shared instance can be used from several
                # threads; since access is serialized, sqlite3's same-thread check
                # would only get in the way.
                new_conn = sqlite3.connect(self.dbfile, check_same_thread=False)
            else:
                self.logger.error("错误的数据库类型,请检查")
                return False
        except Exception as e:
            self.logger.error("服务器端数据库连接失败:%s" % str(e))
            return False
        # Release the previous (possibly broken) connection before replacing it.
        self._close_quietly()
        self.connection = new_conn
        self.ok = True
        self.logger.debug("服务器端数据库连接成功")
        return True

    def Retest_conn(self):
        """Make sure a usable connection exists, reconnecting if necessary.

        Only MySQL connections are probed with ping(); an existing SQLite
        connection is assumed healthy.
        :return: True if a connection is available.
        """
        if self.connection is None:
            return self.connect()
        if self.itype == 0:
            try:
                self.connection.ping()
            except Exception:
                return self.connect()
        return True

    def _adapt_sql(self, strsql):
        """Translate pymysql-style ``%s`` placeholders to sqlite3's ``?`` style."""
        if self.itype == 1:
            return strsql.replace("%s", "?")
        return strsql

    def _rollback_quietly(self):
        """Roll back the current transaction; a failing rollback is logged, not raised."""
        try:
            self.connection.rollback()
        except Exception as e:
            self.logger.error("数据库回滚失败:%s" % str(e))

    def do_select(self, strsql, itype=0):
        """Run a raw SELECT statement (no parameters).

        :param itype: 1 - return only the first row (fetchone); otherwise all rows.
        :return: the fetched row(s), or None on connection failure or query error.
        """
        with self.lock:
            if not self.Retest_conn():
                return None
            try:
                # Commit first so a long-lived InnoDB transaction does not keep
                # serving a stale snapshot to this SELECT.
                self.connection.commit()
                with closing(self.connection.cursor()) as cursor:
                    cursor.execute(strsql)
                    if itype == 1:
                        return cursor.fetchone()
                    return cursor.fetchall()
            except Exception as e:
                self.logger.error("do_select异常报错:%s" % str(e))
                return None

    def do_sql(self, strsql, data=None):
        """Execute a write statement and commit.

        :param data: when truthy, a sequence of parameter tuples for executemany
                     (batch execution); otherwise ``strsql`` is run as-is.
        :return: True on success, False otherwise (the transaction is rolled back).
        """
        with self.lock:
            if not self.Retest_conn():
                return False
            try:
                with closing(self.connection.cursor()) as cursor:
                    if data:
                        cursor.executemany(self._adapt_sql(strsql), data)
                    else:
                        cursor.execute(strsql)
                self.connection.commit()
                return True
            except Exception as e:
                self.logger.error("执行数据库语句%s出错:%s" % (strsql, str(e)))
                self._rollback_quietly()
                return False

    def safe_do_sql(self, strsql, params, itype=0):
        """Execute a parameterized write statement and commit.

        :param params: a tuple of parameters matching the ``%s`` placeholders.
        :param itype: 1 - also return the new row id (only used when inserting a task).
        :return: (ok, task_id); task_id is 0 unless itype == 1 and the insert succeeded.
        """
        with self.lock:
            if not self.Retest_conn():
                return False, 0
            try:
                with closing(self.connection.cursor()) as cursor:
                    cursor.execute(self._adapt_sql(strsql), params)
                    self.connection.commit()
                    task_id = cursor.lastrowid if itype == 1 else 0
                return True, task_id
            except Exception as e:
                self.logger.error("执行数据库语句%s出错:%s" % (strsql, str(e)))
                self._rollback_quietly()
                return False, 0

    def safe_do_select(self, strsql, params, itype=0):
        """Run a parameterized SELECT.

        :param itype: 0 - all rows (fetchall); 1 - one row (fetchone).
        :return: the fetched row(s); [] on connection failure or query error.
        """
        results = []
        with self.lock:
            if not self.Retest_conn():
                return results
            try:
                # Commit first so the SELECT sees rows committed by other connections.
                self.connection.commit()
                with closing(self.connection.cursor()) as cursor:
                    cursor.execute(self._adapt_sql(strsql), params)
                    if itype == 0:
                        results = cursor.fetchall()
                    elif itype == 1:
                        results = cursor.fetchone()
            except Exception as e:
                self.logger.error(f"查询出错: {e}")
        return results

    def is_json(self, s: str) -> bool:
        """Return True if ``s`` is a string containing valid JSON."""
        if not isinstance(s, str):
            return False
        try:
            json.loads(s)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            return False
        return True

    @staticmethod
    def timedelta_to_str(delta: timedelta) -> str:
        """Format ``delta`` as ``HH:MM:SS``.

        Hours may exceed 24 and the sub-second part is truncated. Negative
        durations get a leading '-', e.g. ``-00:00:05``.
        """
        total = int(delta.total_seconds())
        sign = "-" if total < 0 else ""
        hours, remainder = divmod(abs(total), 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{sign}{hours:02}:{minutes:02}:{seconds:02}"

    @staticmethod
    def _to_json_text(value, ensure_ascii=True):
        """Normalize ``value`` to JSON text for storage.

        A string that already parses as JSON is returned unchanged. Any other
        value is serialized with json.dumps. A value json.dumps cannot handle,
        or a non-JSON string, is stored as the JSON string of ``str(value)``.
        """
        try:
            if isinstance(value, str):
                json.loads(value)  # validate only
                return value
            return json.dumps(value, ensure_ascii=ensure_ascii)
        except (TypeError, ValueError):
            return json.dumps(str(value), ensure_ascii=ensure_ascii)

    @staticmethod
    def _decode_escapes(text):
        """Resolve backslash escapes (``\\uXXXX``, ``\\n``, ``\\"``) into real characters.

        The 'unicode_escape' codec reads its input bytes as Latin-1. Encoding
        with Latin-1 + 'backslashreplace' turns characters above U+00FF back
        into ``\\uXXXX`` escapes, so non-ASCII text round-trips instead of
        becoming mojibake. Pure-ASCII input gives exactly the same result as
        ``text.encode('utf-8').decode('unicode_escape')``.
        """
        return text.encode('latin-1', 'backslashreplace').decode('unicode_escape')

    @staticmethod
    def _filled(value):
        """Return True when an optional string filter was supplied (not None/empty/blank)."""
        return bool(value and value.strip())

    # --------------------- Task-specific database operations ---------------------
    def get_system_info(self):
        """Return the first row of zf_system (all columns)."""
        strsql = "select * from zf_system;"
        return self.do_select(strsql, 1)

    def get_run_tasks(self):
        """Return all tasks whose task_status is not 2 (over_task sets status 2)."""
        strsql = "select ID,task_target,task_status,work_type,cookie_info,llm_type from task where task_status <> 2;"
        return self.do_select(strsql)

    def start_task(self, test_target, cookie_info, work_type, llm_type) -> int:
        """Insert a new detection task with task_status=1 and safe_rank=0.

        :param test_target: target under test; also used as the task name.
        :param cookie_info: cookie data stored with the task.
        :param work_type: value stored in task.work_type.
        :param llm_type: value stored in task.llm_type.
        :return: the new task ID, or 0 if the insert failed.
        """
        start_time = get_local_timestr()
        sql = "INSERT INTO task (task_name,task_target,start_time,task_status,safe_rank,work_type,cookie_info,llm_type) " \
              "VALUES (%s,%s,%s,%s,%s,%s,%s,%s)"
        params = (test_target, test_target, start_time, 1, 0, work_type, cookie_info, llm_type)
        _, task_id = self.safe_do_sql(sql, params, 1)
        return task_id

    def over_task(self, task_id):
        """Mark a task finished (task_status=2). Returns True on success."""
        strsql = "update task set task_status=2 where ID=%s;"
        bok, _ = self.safe_do_sql(strsql, (task_id,))
        return bok

    def del_task(self, task_id):
        """Delete a task and its llm/result/vul rows.

        Each delete commits on its own. All of them are attempted even if one
        fails.
        :return: True only if every delete succeeded.
        """
        params = (task_id,)
        statements = (
            "delete from task where ID=%s;",
            "delete from task_llm where task_id=%s;",
            "delete from task_result where task_id=%s;",
            "delete from task_vul where task_id=%s;",
        )
        results = [self.safe_do_sql(strsql, params)[0] for strsql in statements]
        return all(results)

    def insetr_result(self, task_id, instruction, result, do_sn, start_time, end_time, source_result, ext_params, node_path):
        """Store the result of one executed instruction in task_result.

        ``result`` and ``source_result`` are normalized to JSON text, with
        non-ASCII characters kept as-is. ``ext_params`` must provide the
        'is_user' and 'is_vulnerability' keys.
        :return: True if the row was inserted.
        """
        str_result = self._to_json_text(result, ensure_ascii=False)
        str_source_result = self._to_json_text(source_result, ensure_ascii=False)
        sql = """
            INSERT INTO task_result (task_id, instruction, result, do_sn, start_time, end_time,
                                     source_result, is_user, is_vulnerability, node_path)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        # Bind the normalized JSON text, not the raw source_result object.
        params = (task_id, instruction, str_result, do_sn, start_time, end_time, str_source_result,
                  ext_params['is_user'], ext_params['is_vulnerability'], node_path)
        bok, _ = self.safe_do_sql(sql, params)
        return bok

    def insert_llm(self, task_id, prompt, reasoning_content, content, post_time, llm_sn, path):
        """Store one LLM exchange in task_llm.

        ``reasoning_content`` and ``content`` are normalized to JSON text. Their
        escape sequences are then resolved so the stored text is human-readable
        (e.g. Chinese instead of ``\\uXXXX``).
        :return: True if the row was inserted.
        """
        str_reasoning = self._decode_escapes(self._to_json_text(reasoning_content))
        str_content = self._decode_escapes(self._to_json_text(content))
        sql = """
            INSERT INTO task_llm (task_id,do_sn,prompt,reasoning_content,content,start_time,node_path)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
        """
        params = (task_id, llm_sn, prompt, str_reasoning, str_content, post_time, path)
        bok, _ = self.safe_do_sql(sql, params)
        return bok

    def get_task_instrs(self, task_id, nodename):
        """Return the executed instructions of a task, optionally filtered by node path.

        :param nodename: LIKE pattern matched against task_result.node_path;
                         None/blank means no filter.
        """
        strsql = "select ID,node_path,do_sn,instruction,result from task_result where task_id = %s"
        if self._filled(nodename):
            # insetr_result writes the node into node_path; this module never
            # writes a 'nodename' column.
            strsql += " and node_path like %s;"
            params = (task_id, nodename)
        else:
            strsql += ";"
            params = (task_id,)
        return self.safe_do_select(strsql, params)

    def insert_taks_vul(self, task_id, node_name, node_path, vul_type, vul_level, vul_info):
        """Insert one vulnerability finding into task_vul. Returns True on success."""
        strsql = '''
            INSERT INTO task_vul (task_id,node_name,node_path,vul_type,vul_level,vul_info)
            VALUES (%s,%s,%s,%s,%s,%s)
        '''
        params = (task_id, node_name, node_path, vul_type, vul_level, vul_info)
        bok, _ = self.safe_do_sql(strsql, params)
        return bok

    def get_task_vul(self, task_id, nodename, vultype, vullevel):
        """Return a task's vulnerability findings.

        Any non-blank filter among nodename (matched exactly against node_path),
        vultype and vullevel narrows the result.
        """
        strsql = "select ID,node_path,vul_type,vul_level,vul_info from task_vul"
        conditions = ["task_id=%s"]  # task_id is always required
        params = [task_id]
        if self._filled(nodename):
            conditions.append("node_path=%s")
            params.append(nodename)
        if self._filled(vultype):
            conditions.append("vul_type=%s")
            params.append(vultype)
        if self._filled(vullevel):
            conditions.append("vul_level=%s")
            params.append(vullevel)
        strsql += " WHERE " + " AND ".join(conditions)
        return self.safe_do_select(strsql, tuple(params))

    def get_task_node_done_instr(self, task_id, nodepath):
        """Return the instructions already executed for one node of a task, newest first."""
        strsql = '''
            select instruction,start_time,result from task_result
            where task_id=%s and node_path=%s order by start_time desc;
        '''
        params = (task_id, nodepath)
        return self.safe_do_select(strsql, params)

    def get_his_tasks(self, target_name, safe_rank, llm_type, start_time, end_time):
        """Return finished tasks (task_status=2) matching the optional filters.

        ``start_time`` / ``end_time`` are 'YYYY-MM-DD' dates. Both are inclusive:
        the end date is turned into an exclusive bound at the next day's
        00:00:00. A malformed date raises ValueError from datetime.strptime.
        """
        strsql = "select ID,task_target,safe_rank,llm_type,start_time,end_time from task"
        conditions = ["task_status=%s"]
        params = [2]
        if self._filled(target_name):
            conditions.append("task_target=%s")
            params.append(target_name)
        if self._filled(safe_rank):
            conditions.append("safe_rank=%s")
            params.append(safe_rank)
        if self._filled(llm_type):
            conditions.append("llm_type=%s")
            params.append(llm_type)
        if self._filled(start_time):
            conditions.append("start_time >= %s")
            start_date = datetime.strptime(start_time, "%Y-%m-%d")
            params.append(start_date.strftime("%Y-%m-%d 00:00:00"))  # start of that day
        if self._filled(end_time):
            conditions.append("end_time < %s")
            end_date = datetime.strptime(end_time, "%Y-%m-%d")
            params.append((end_date + timedelta(days=1)).strftime("%Y-%m-%d 00:00:00"))  # next day 00:00
        strsql += " WHERE " + " AND ".join(conditions)
        return self.safe_do_select(strsql, tuple(params))

    def getsystem_info(self):
        """Return (local_ip, version) from the first zf_system row."""
        strsql = "select local_ip,version from zf_system;"
        return self.do_select(strsql, 1)

    def update_localip(self, local_ip):
        """Set zf_system.local_ip (on every row). Returns True on success."""
        strsql = "update zf_system set local_ip=%s;"
        bok, _ = self.safe_do_sql(strsql, (local_ip,))
        return bok

    def test(self):
        """Standalone pymysql usage example with placeholder credentials.

        Opens its own connection instead of the managed one; nothing in this
        module calls it.
        """
        conn = pymysql.connect(
            host='localhost',
            port=3306,
            user='username',
            password='password',
            database='database_name'
        )
        try:
            with closing(conn.cursor()) as cursor:
                query = "SELECT * FROM table_name"
                cursor.execute(query)
                for row in cursor.fetchall():
                    print(row)
        finally:
            conn.close()


# Module-level shared instance, created and connected at import time.
app_DBM = DBManager()
app_DBM.connect()

if __name__ == "__main__":
    mDBM = DBManager()
    mDBM.connect()
    # Manual smoke test: start_task takes four arguments. The work_type and
    # llm_type values below are placeholders.
    print(mDBM.start_task("11", "22", 0, 0))