|
|
|
import pymysql
|
|
|
|
import sqlite3
|
|
|
|
import threading
|
|
|
|
import os
|
|
|
|
import json
|
|
|
|
from myutils.ConfigManager import myCongif
|
|
|
|
from myutils.MyLogger_logger import LogHandler
|
|
|
|
from myutils.MyTime import get_local_timestr
|
|
|
|
from datetime import timedelta
|
|
|
|
|
|
|
|
class DBManager:
    """Lock-serialized wrapper around one MySQL (pymysql) or SQLite connection.

    The backend is chosen by the ``DBType`` configuration value:
    ``0`` selects MySQL and ``1`` selects SQLite.  Every public query method
    takes ``self.lock`` for its whole duration, so a single instance can be
    shared between threads.

    SQL passed to the parameterized helpers uses the pymysql ``%s``
    placeholder style; it is rewritten to ``?`` when the SQLite backend is
    active.
    """

    def __init__(self):
        """Read the connection settings from the configuration.

        No connection is opened here; call :meth:`connect` afterwards.

        Raises:
            FileNotFoundError: SQLite backend selected and the database file
                exists neither at the configured path nor one directory up.
        """
        self.logger = LogHandler().get_logger("DBManager")
        self.lock = threading.Lock()
        self.itype = myCongif.get_data("DBType")
        self.ok = False
        # Always defined so that every backend can test "is connection None".
        self.connection = None
        if self.itype == 0:
            self.host = myCongif.get_data('mysql.host')
            self.port = myCongif.get_data('mysql.port')
            self.user = myCongif.get_data('mysql.user')
            self.passwd = myCongif.get_data('mysql.passwd')
            self.database = myCongif.get_data('mysql.database')
        elif self.itype == 1:
            self.dbfile = myCongif.get_data("sqlite")
            if not os.path.exists(self.dbfile):
                # When this module is run directly the working directory is
                # not the project root, so retry one level up.
                self.dbfile = "../" + self.dbfile
            if not os.path.exists(self.dbfile):
                raise FileNotFoundError(f"Database file {self.dbfile} does not exist.")
        else:
            self.logger.error("错误的数据库类型,请检查")

    def __del__(self):
        """Close the connection, tolerating a partially constructed instance."""
        # __del__ also runs when __init__ raised, so attributes may be missing.
        connection = getattr(self, "connection", None)
        if getattr(self, "ok", False) and connection is not None:
            try:
                connection.close()
            except Exception:
                # Best effort only: exceptions cannot propagate usefully from
                # a finalizer.
                pass
            self.connection = None
        logger = getattr(self, "logger", None)
        if logger is not None:
            logger.debug("DBManager销毁")

    def connect(self):
        """Open the connection for the configured backend.

        Returns:
            bool: True on success; False when connecting failed or the
            configured backend type is not supported.
        """
        try:
            if self.itype == 0:
                self.connection = pymysql.connect(host=self.host, port=self.port, user=self.user,
                                                  passwd=self.passwd, db=self.database, charset='utf8')
            elif self.itype == 1:
                # All access is serialized through self.lock, so the
                # connection may safely be used from threads other than the
                # one that created it.
                self.connection = sqlite3.connect(self.dbfile, check_same_thread=False)
            else:
                self.logger.error("错误的数据库类型,请检查")
                return False
        except Exception as e:
            self.logger.error("服务器端数据库连接失败:%s" % str(e))
            return False
        self.ok = True
        self.logger.debug("服务器端数据库连接成功")
        return True

    def Retest_conn(self):
        """Make sure a usable connection exists, reconnecting when needed.

        MySQL connections are pinged; SQLite connections are only checked for
        having been opened at all.

        Returns:
            bool: True when a connection is available.
        """
        if self.connection is None:
            return self.connect()
        if self.itype == 0:
            try:
                self.connection.ping()
            except Exception:
                return self.connect()
        return True

    # --------------------- internal helpers ---------------------
    def _adapt_placeholders(self, strsql):
        """Rewrite pymysql ``%s`` placeholders to ``?`` for the SQLite backend.

        This is a plain textual replacement; named ``%(name)s`` placeholders
        are not translated.
        """
        if self.itype == 1:
            return strsql.replace("%s", "?")
        return strsql

    def _execute(self, cursor, strsql, params=None):
        """Run one statement on ``cursor`` with backend-appropriate parameters."""
        if params is None:
            return cursor.execute(strsql)
        if self.itype == 1:
            strsql = self._adapt_placeholders(strsql)
            # pymysql accepts a bare scalar; sqlite3 needs a sequence or dict.
            if not isinstance(params, (tuple, list, dict)):
                params = (params,)
        return cursor.execute(strsql, params)

    def _close_cursor(self, cursor):
        """Close ``cursor`` if one was created, logging instead of raising."""
        if cursor is None:
            return
        try:
            cursor.close()
        except Exception as e:
            self.logger.debug("关闭游标失败:%s" % str(e))

    def _rollback_quietly(self):
        """Roll back the current transaction without masking the first error."""
        try:
            self.connection.rollback()
        except Exception as e:
            self.logger.error("事务回滚失败:%s" % str(e))

    @staticmethod
    def _to_json_text(value):
        """Return ``value`` as JSON text.

        Strings that already are valid JSON are returned unchanged; any other
        string is encoded as a JSON string.  Non-string values are dumped, and
        values that cannot be dumped are stringified first.
        """
        if isinstance(value, str):
            try:
                json.loads(value)
            except ValueError:
                return json.dumps(value, ensure_ascii=False)
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except (TypeError, ValueError):
            return json.dumps(str(value), ensure_ascii=False)

    @staticmethod
    def _unescape_text(text):
        """Expand backslash escapes in ``text`` without corrupting non-ASCII.

        Encoding with latin-1/backslashreplace keeps every character
        representable before ``unicode_escape`` decoding, and the UTF-16 round
        trip joins surrogate pairs produced by ``\\uD83D\\uDE00``-style
        escapes.  On a decoding error the text is returned unchanged.
        """
        try:
            expanded = text.encode("latin-1", "backslashreplace").decode("unicode_escape")
            return expanded.encode("utf-16", "surrogatepass").decode("utf-16", "surrogatepass")
        except UnicodeError:
            return text

    # --------------------- generic query helpers ---------------------
    def do_select(self, strsql, itype=0):
        """Run a SELECT statement that carries no bound parameters.

        Args:
            strsql: Complete SQL text.
            itype: 1 fetches a single row; any other value fetches all rows.

        Returns:
            The fetched row or rows, or None when the connection is
            unavailable or the query failed.
        """
        with self.lock:
            if not self.Retest_conn():
                return None
            cursor = None
            try:
                # Commit first so the SELECT starts a fresh transaction and
                # sees the latest data (InnoDB transaction snapshot).
                self.connection.commit()
                cursor = self.connection.cursor()
                cursor.execute(strsql)
                if itype == 1:
                    return cursor.fetchone()
                return cursor.fetchall()
            except Exception as e:
                self.logger.error("do_select异常报错:%s" % str(e))
                return None
            finally:
                self._close_cursor(cursor)

    def do_sql(self, strsql, data=None):
        """Execute a data-modifying statement and commit it.

        Args:
            strsql: SQL text.
            data: Optional sequence of parameter rows; when truthy the
                statement is run once per row with ``executemany``.

        Returns:
            bool: True when the statement ran and was committed.
        """
        with self.lock:
            if not self.Retest_conn():
                return False
            cursor = None
            try:
                cursor = self.connection.cursor()
                if data:
                    cursor.executemany(self._adapt_placeholders(strsql), data)
                else:
                    cursor.execute(strsql)
                self.connection.commit()
                return True
            except Exception as e:
                self.logger.error("执行数据库语句%s出错:%s" % (strsql, str(e)))
                self._rollback_quietly()
                return False
            finally:
                self._close_cursor(cursor)

    def safe_do_sql(self, strsql, params, itype=0):
        """Execute one parameterized data-modifying statement and commit it.

        Args:
            strsql: SQL text using ``%s`` placeholders.
            params: Values bound to the placeholders.
            itype: When 1, the cursor's ``lastrowid`` is returned as well.

        Returns:
            tuple: ``(ok, row_id)``; ``row_id`` stays 0 unless ``itype`` is 1
            and the statement succeeded.
        """
        bok = False
        task_id = 0
        with self.lock:
            if self.Retest_conn():
                cursor = None
                try:
                    cursor = self.connection.cursor()
                    self._execute(cursor, strsql, params)
                    self.connection.commit()
                    if itype == 1:
                        task_id = cursor.lastrowid
                    bok = True
                except Exception as e:
                    self.logger.error("执行数据库语句%s出错:%s" % (strsql, str(e)))
                    self._rollback_quietly()
                finally:
                    self._close_cursor(cursor)
        return bok, task_id

    def safe_do_select(self, strsql, params, itype=0):
        """Run one parameterized SELECT statement.

        Args:
            strsql: SQL text using ``%s`` placeholders.
            params: Values bound to the placeholders.
            itype: 0 fetches all rows, 1 fetches a single row.

        Returns:
            The fetched rows (``itype`` 0) or row (``itype`` 1); an empty list
            when the connection is unavailable, the query failed, or
            ``itype`` is neither 0 nor 1.
        """
        results = []
        with self.lock:
            if self.Retest_conn():
                cursor = None
                try:
                    # Same reason as in do_select: refresh the snapshot.
                    self.connection.commit()
                    cursor = self.connection.cursor()
                    self._execute(cursor, strsql, params)
                    if itype == 0:
                        results = cursor.fetchall()
                    elif itype == 1:
                        results = cursor.fetchone()
                except Exception as e:
                    self.logger.error(f"查询出错: {e}")
                finally:
                    self._close_cursor(cursor)
        return results

    def is_json(self, s: str) -> bool:
        """Return True when ``s`` is a string containing valid JSON."""
        if not isinstance(s, str):
            return False
        try:
            json.loads(s)
        except Exception:
            # Covers JSONDecodeError as well as any other parsing failure.
            return False
        return True

    @staticmethod
    def timedelta_to_str(delta: timedelta) -> str:
        """Format ``delta`` as ``HH:MM:SS`` (hours are not wrapped at 24).

        Fractions of a second are truncated; negative durations are rendered
        with a leading ``-``.
        """
        total = int(delta.total_seconds())
        sign = "-" if total < 0 else ""
        hours, remainder = divmod(abs(total), 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{sign}{hours:02}:{minutes:02}:{seconds:02}"

    # --------------------- application-specific operations ---------------------
    def get_system_info(self):
        """Return the first row of ``zf_system``, or None on failure."""
        strsql = "select * from zf_system;"
        return self.do_select(strsql, 1)

    def get_run_tasks(self):
        """Return every task whose ``task_status`` is not 2, or None on failure."""
        strsql = "select ID,task_target,task_status,work_type,cookie_info,llm_type from task where task_status <> 2;"
        return self.do_select(strsql)

    def start_task(self, test_target, cookie_info, work_type, llm_type) -> int:
        """Insert a new row into ``task`` and return its id.

        The row is created with ``task_status`` 1 and ``safe_rank`` 0, and
        ``test_target`` is stored as both ``task_name`` and ``task_target``.

        Args:
            test_target: Target under test.
            cookie_info: Value stored in the ``cookie_info`` column.
            work_type: Value stored in the ``work_type`` column.
            llm_type: Value stored in the ``llm_type`` column.

        Returns:
            int: Id of the inserted row, or 0 when the insert failed.
        """
        start_time = get_local_timestr()
        sql = "INSERT INTO task (task_name,task_target,start_time,task_status,safe_rank,work_type,cookie_info,llm_type) " \
              "VALUES (%s,%s,%s,%s,%s,%s,%s,%s)"
        params = (test_target, test_target, start_time, 1, 0, work_type, cookie_info, llm_type)
        _, task_id = self.safe_do_sql(sql, params, 1)
        return task_id

    def over_task(self, task_id):
        """Set ``task_status`` to 2 for the given task; return True on success."""
        strsql = "update task set task_status=2 where ID=%s;"
        params = (task_id,)  # one-element tuple; "(task_id)" is just a scalar
        bok, _ = self.safe_do_sql(strsql, params)
        return bok

    def insetr_result(self, task_id, instruction, result, do_sn, start_time, end_time, source_result, ext_params, node_path):
        """Store the outcome of one executed instruction in ``task_result``.

        ``result`` and ``source_result`` are both stored as JSON text.

        Args:
            ext_params: Mapping that must contain the keys ``is_user`` and
                ``is_vulnerability``.

        Returns:
            bool: True when the row was inserted.

        Raises:
            KeyError: ``ext_params`` lacks one of the required keys.
        """
        str_result = self._to_json_text(result)
        str_source_result = self._to_json_text(source_result)
        sql = """
        INSERT INTO task_result
        (task_id, instruction, result, do_sn,start_time,end_time,source_result,is_user,is_vulnerability,node_path)
        VALUES
        (%s, %s, %s, %s, %s, %s,%s,%s,%s,%s)
        """
        # Bind the serialized source result; the raw object may not be a type
        # the database driver can bind.
        params = (task_id, instruction, str_result, do_sn, start_time, end_time, str_source_result,
                  ext_params['is_user'], ext_params['is_vulnerability'], node_path)
        bok, _ = self.safe_do_sql(sql, params)
        return bok

    def insert_llm(self, task_id, prompt, reasoning_content, content, post_time, llm_sn, path):
        """Store one LLM exchange in ``task_llm``.

        ``reasoning_content`` and ``content`` are converted to JSON text and
        their backslash escapes are then expanded, so the stored text holds
        the literal characters.

        Returns:
            bool: True when the row was inserted.
        """
        str_reasoning = self._unescape_text(self._to_json_text(reasoning_content))
        str_content = self._unescape_text(self._to_json_text(content))
        sql = """
        INSERT INTO task_llm
        (task_id,do_sn,prompt,reasoning_content,content,start_time,node_path)
        VALUES
        (%s, %s, %s, %s, %s, %s,%s)
        """
        params = (task_id, llm_sn, prompt, str_reasoning, str_content, post_time, path)
        bok, _ = self.safe_do_sql(sql, params)
        return bok

    def get_task_instrs(self, task_id, nodename):
        """Placeholder for per-task instruction lookup; always returns []."""
        instrs = []
        return instrs

    def get_task_vul(self, task_id, nodename, vultype, vullevel):
        """Placeholder for per-task vulnerability lookup; always returns []."""
        vuls = []
        return vuls

    def get_task_node_done_instr(self, task_id, nodepath):
        """Return ``(instruction, start_time, result)`` rows already recorded
        for the given task and node path (an empty list on failure)."""
        strsql = '''
        select instruction,start_time,result from task_result where task_id=%s and node_path=%s;
        '''
        params = (task_id, nodepath)
        return self.safe_do_select(strsql, params)

    def test(self):
        """Manual smoke test against a local MySQL server.

        Uses hard-coded placeholder credentials and table name; it does not
        touch this instance's own connection.
        """
        conn = pymysql.connect(
            host='localhost',
            port=3306,
            user='username',
            password='password',
            database='database_name'
        )
        try:
            cursor = conn.cursor()
            try:
                query = "SELECT * FROM table_name"
                cursor.execute(query)
                for row in cursor.fetchall():
                    print(row)
            finally:
                cursor.close()
        finally:
            # Release the connection even when the query fails.
            conn.close()
|
|
|
|
|
|
|
|
# Module-wide shared instance; it is created and connected at import time.
app_DBM = DBManager()
app_DBM.connect()

if __name__ == "__main__":
    mDBM = DBManager()
    mDBM.connect()
    # start_task takes (test_target, cookie_info, work_type, llm_type); the
    # previous two-argument call raised TypeError.
    # NOTE(review): 0 is only a placeholder for work_type and llm_type --
    # confirm the values this smoke run should use.
    print(mDBM.start_task("11", "22", 0, 0))
|