import builtins
|
import copy
|
import hashlib
|
import json
|
import logging
|
import queue
|
import random
|
import socket
|
import socketserver
|
import threading
|
import time
|
|
import constant
|
import socket_manager
|
from db import mysql_data
|
from db.redis_manager import RedisUtils, RedisManager
|
from log_module import log
|
from log_module.log import logger_debug
|
from middle_l1_data_server import L1DataManager
|
from output import push_msg_manager
|
from utils import socket_util, kpl_api_util, hosting_api_util, kp_client_msg_manager, global_data_cache_util, tool, \
|
block_web_api
|
from utils.juejin_util import JueJinHttpApi
|
|
# Module-level FIFO queue. Nothing in the visible part of this file puts to or
# gets from it, so it is presumably used by importers of this module
# (NOTE(review): confirm it is still needed).
trade_data_request_queue: "queue.Queue" = queue.Queue()

# Cache of MySQL configuration dicts keyed by database name; it is filled
# lazily by get_mysql_config().
__mysql_config_dict: dict = {}
|
|
|
def get_mysql_config(db_name):
    """Return the MySQL configuration dict for the database *db_name*.

    The configuration is a deep copy of ``constant.MYSQL_CONFIG`` whose
    ``"database"`` entry is set to *db_name*.  It is built once per database
    name and memoised in the module-level ``__mysql_config_dict``, so repeated
    calls with the same name return the very same dict object.

    :param db_name: database name; used both as the cache key and as the
        value of the ``"database"`` entry.
    :return: the (cached) configuration dict.
    """
    try:
        return __mysql_config_dict[db_name]
    except KeyError:
        # First request for this database: build and remember its config.
        pass
    settings = copy.deepcopy(constant.MYSQL_CONFIG)
    settings["database"] = db_name
    __mysql_config_dict[db_name] = settings
    return settings
|
|
|
class MyTCPServer(socketserver.TCPServer):
    """``socketserver.TCPServer`` that always binds and activates on creation."""

    def __init__(self, server_address, RequestHandlerClass):
        """Create the server and immediately bind/listen on *server_address*.

        :param server_address: ``(host, port)`` tuple to listen on.
        :param RequestHandlerClass: handler class instantiated per connection.
        """
        super().__init__(server_address, RequestHandlerClass, bind_and_activate=True)
|
|
|
# To serve requests asynchronously, the threading variant has to be derived
# again from MyTCPServer (instead of using socketserver.ThreadingTCPServer).
class MyThreadingTCPServer(socketserver.ThreadingMixIn, MyTCPServer):
    """``MyTCPServer`` combined with ``socketserver.ThreadingMixIn``."""
|
|
|
class MyBaseRequestHandle(socketserver.BaseRequestHandler):
    """Per-connection handler of the middle server.

    Incoming messages are JSON documents.  A message is either framed with a
    10-byte header (``##`` followed by 8 characters holding the payload length
    in bytes) or, for legacy peers, sent unframed.  Each message carries a
    ``"type"`` field that selects the action performed in :meth:`handle`.
    Replies are JSON documents passed through ``socket_util.load_header``
    before being sent (presumably that helper adds the same ``##`` header -
    confirm against ``socket_util``).
    """

    __inited = False

    # Frame header size in bytes: "##" + 8 characters of payload length.
    _HEADER_SIZE = 10

    def setup(self):
        """Run the one-time class initialisation before handling a request."""
        self.__init()

    @classmethod
    def __init(cls):
        """Initialise class-level state once.

        :return: ``True`` when the class was already initialised, otherwise
            ``None`` (after performing the initialisation).
        """
        if cls.__inited:
            return True
        cls.__inited = True
        cls.__req_socket_dict = {}

    def __is_sign_right(self, data_json):
        """Verify the ``"sign"`` field of *data_json*.

        The expected signature is the MD5 hex digest of all other entries
        rendered as ``key=value``, sorted, joined with ``&`` and followed by a
        fixed salt.  *data_json* is left unmodified.

        :param data_json: dict containing a ``"sign"`` entry.
        :return: ``True`` when the signature matches.
        :raises KeyError: if ``"sign"`` is missing.
        :raises Exception: if the signature does not match.
        """
        sign = data_json["sign"]
        pairs = sorted(f"{k}={v}" for k, v in data_json.items() if k != "sign")
        raw = "&".join(pairs) + "JiaBei@!*."
        digest = hashlib.md5(raw.encode(encoding='utf-8')).hexdigest()
        if digest != sign:
            raise Exception("签名出错")
        return True

    @staticmethod
    def _recv_exact(skk, size):
        """Read exactly *size* bytes from *skk*.

        ``recv`` may return fewer bytes than requested, so keep reading until
        the amount is complete.  Fewer than *size* bytes are returned only
        when the peer closed the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = skk.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    @classmethod
    def getRecvData(cls, skk):
        """Receive one message from the socket *skk*.

        :param skk: connected socket (anything with a ``recv(size)`` method).
        :return: ``(data, header)`` where *data* is the decoded message text
            and *header* the raw bytes of the first read (the 10-byte frame
            header for framed messages).  *data* is ``""`` when the peer
            closed the connection, including in the middle of a frame.
        :raises ValueError: if a ``##`` header carries a non-numeric length.
        :raises UnicodeDecodeError: if the payload is not valid UTF-8.
        """
        header_size = cls._HEADER_SIZE
        header = skk.recv(header_size)
        if not header:
            return "", header
        # A short first read that may still turn out to be a frame header is
        # topped up, so the length field is never parsed from a partial header.
        if len(header) < header_size and b"##".startswith(header[:2]):
            header += cls._recv_exact(skk, header_size - len(header))
        if header.startswith(b"##"):
            if len(header) < header_size:
                # Peer went away before the header was complete.
                return "", header
            content_length = int(header[2:header_size])
            # Read exactly the announced number of bytes: never consume the
            # next frame, and decode only once so multi-byte characters split
            # across reads are handled correctly.
            body = cls._recv_exact(skk, content_length)
            if len(body) < content_length:
                return "", header
            return body.decode('utf-8'), header
        # Legacy, unframed message: whatever arrives next belongs to it.
        rest = skk.recv(1024 * 1024)
        return (header + rest).decode('utf-8'), header

    @staticmethod
    def _send_framed(sk, payload):
        """Send the text *payload* on *sk* wrapped by ``socket_util.load_header``."""
        sk.sendall(socket_util.load_header(payload.encode(encoding='utf-8')))

    @staticmethod
    def _run_redis_cmd(spec):
        """Execute one redis command described by *spec* and return its result.

        *spec* provides ``db``, ``cmd``, ``key`` and optionally ``args``.  The
        command is called as ``RedisUtils.<cmd>(redis, key, *args)``.  For
        ``setex`` a list value (second argument) is stored JSON-encoded.  A
        ``set`` result is converted to a list so it can be JSON-serialised.

        NOTE(security): ``cmd`` comes straight from the network and is
        resolved with ``getattr``; any attribute of ``RedisUtils`` can be
        invoked by a client.  Consider an explicit allow-list.
        """
        db = spec["db"]
        cmd = spec["cmd"]
        key = spec["key"]
        args = spec.get("args")
        redis = RedisManager(db).getRedis()
        method = getattr(RedisUtils, cmd)
        call_args = [redis, key]
        if args is not None:
            if isinstance(args, (tuple, list)):
                args = list(args)
                if cmd == "setex":
                    call_args.append(args[0])
                    value = args[1]
                    call_args.append(json.dumps(value) if isinstance(value, list) else value)
                else:
                    call_args.extend(args)
            else:
                call_args.append(args)
        result = method(*call_args)
        if isinstance(result, set):
            result = list(result)
        return result

    @staticmethod
    def _third_blocks_response(source, code):
        """Build the JSON reply text for a ``get_third_blocks`` request.

        *source* selects the provider: 2 = TDX, 3 = THS, 4 = Eastmoney.  Any
        other value yields the "source mismatch" error reply.
        """
        result_str = json.dumps({"code": 1, "msg": "source不匹配"})
        if source == 2:
            # TDX
            try:
                blocks = block_web_api.get_tdx_blocks(code)
                result_str = json.dumps({"code": 0, "data": list(blocks)})
            except Exception as e:
                result_str = json.dumps({"code": 1, "msg": str(e)})
        elif source == 3:
            # THS: on failure reload the hexin-v value and retry once.
            try:
                blocks = block_web_api.THSBlocksApi().get_ths_blocks(code)
                result_str = json.dumps({"code": 0, "data": list(blocks)})
            except Exception:
                try:
                    block_web_api.THSBlocksApi.load_hexin_v()
                    blocks = block_web_api.THSBlocksApi().get_ths_blocks(code)
                    result_str = json.dumps({"code": 0, "data": list(blocks)})
                except Exception as e1:
                    result_str = json.dumps({"code": 1, "msg": str(e1)})
        elif source == 4:
            # Eastmoney
            try:
                blocks = block_web_api.get_eastmoney_block(code)
                result_str = json.dumps({"code": 0, "data": list(blocks)})
            except Exception as e:
                result_str = json.dumps({"code": 1, "msg": str(e)})
        return result_str

    def _serve_registered_client(self, sk, rid):
        """Read heartbeats from a registered client until the session ends.

        :param sk: the client's socket.
        :param rid: id the client registered with.
        :return: ``True`` when the connection is finished (peer closed it or
            reset it), ``False`` when the loop was aborted by an unexpected
            error that has been logged.
        """
        try:
            while True:
                result, header = self.getRecvData(sk)
                if not result:
                    # Peer closed the connection: nothing to parse any more.
                    sk.close()
                    return True
                try:
                    result_json = json.loads(result)
                    if result_json["type"] == 'heart':
                        # Record the client as alive.
                        socket_manager.ClientSocketManager.heart(result_json['client_id'])
                except json.decoder.JSONDecodeError:
                    print("JSON解析出错", result, header)
                time.sleep(1)
        except ConnectionResetError:
            socket_manager.ClientSocketManager.del_client(rid)
            return True
        except Exception as e:
            logging.exception(e)
            return False

    def handle(self):
        """Serve one client connection until it disconnects.

        Messages are read in a loop and dispatched on their ``"type"``:

        * ``register`` - register the socket as a long-lived client and then
          only process heartbeats;
        * ``response``, ``l2_subscript_codes``, ``l2_subscript_codes_rate``,
          ``l2_position_subscript_codes`` - store the pushed data and always
          acknowledge with ``{"code": 0}``;
        * ``redis``, ``mysql``, ``juejin``, ``kpl`` - proxy the request and
          reply with ``{"code": 0, "data": ...}`` or ``{"code": 1, "msg": ...}``;
        * ``kp_msg``, ``push_msg`` - forward a message and acknowledge;
        * ``l1_data``, ``get_l1_target_codes``, ``get_third_blocks`` - handle
          the request and then end the connection;
        * ``low_suction`` - placeholder, no action yet.

        Unknown types are ignored without a reply.  Requests that take longer
        than 2 seconds are logged.
        """
        super().handle()
        sk: socket.socket = self.request
        while True:
            try:
                data, header = self.getRecvData(sk)
                if not data:
                    # Empty read: the peer disconnected.
                    break
                try:
                    data_json = json.loads(data)
                except json.decoder.JSONDecodeError:
                    # Tell the peer the payload was not valid JSON and keep going.
                    self._send_framed(sk, json.dumps({"code": 100, "msg": "JSON解析失败"}))
                    continue
                msg_type = data_json["type"]
                started_at = time.time()
                try:
                    if msg_type == 'register':
                        client_type = data_json["data"]["client_type"]
                        rid = data_json["rid"]
                        socket_manager.ClientSocketManager.add_client(client_type, rid, sk)
                        # The register acknowledgement is sent without a frame header.
                        sk.sendall(json.dumps({"type": "register"}).encode(encoding='utf-8'))
                        if self._serve_registered_client(sk, rid):
                            # Socket closed/reset: do not read from it again.
                            break
                    elif msg_type == "response":
                        # Reply to a request this server initiated.
                        try:
                            client_id = data_json["client_id"]
                            hosting_api_util.set_response(client_id, data_json["request_id"], data_json['data'])
                        finally:
                            self._send_framed(sk, json.dumps({"code": 0}))
                    elif msg_type == "l2_subscript_codes":
                        # Store the subscribed codes.
                        try:
                            datas = data_json["data"]["data"]
                            global_data_cache_util.huaxin_subscript_codes = datas
                            global_data_cache_util.huaxin_subscript_codes_update_time = tool.get_now_time_str()
                        finally:
                            self._send_framed(sk, json.dumps({"code": 0}))
                    elif msg_type == "l2_subscript_codes_rate":
                        # Store the price-change rates of the subscribed codes.
                        try:
                            datas = data_json["data"]["data"]
                            global_data_cache_util.huaxin_subscript_codes_rate = datas
                        finally:
                            self._send_framed(sk, json.dumps({"code": 0}))
                    elif msg_type == "l2_position_subscript_codes":
                        # Store the subscribed position codes.
                        try:
                            datas = data_json["data"]["data"]
                            print("l2_position_subscript_codes", data_json)
                            global_data_cache_util.huaxin_position_subscript_codes = datas
                            global_data_cache_util.huaxin_position_subscript_codes_update_time = tool.get_now_time_str()
                        finally:
                            self._send_framed(sk, json.dumps({"code": 0}))
                    elif msg_type == "redis":
                        try:
                            data = data_json["data"]
                            ctype = data["ctype"]
                            result_str = ''
                            if ctype == "queue_size":
                                # TODO set the queue size
                                result_str = json.dumps({"code": 0})
                            elif ctype == "cmd":
                                result = self._run_redis_cmd(data["data"])
                                result_str = json.dumps({"code": 0, "data": result})
                            elif ctype == "cmds":
                                result_list = [self._run_redis_cmd(d) for d in data["data"]]
                                result_str = json.dumps({"code": 0, "data": result_list})
                            self._send_framed(sk, result_str)
                        except Exception as e:
                            logger_debug.exception(e)
                            logger_debug.info(f"Redis操作出错:data_json:{data_json}")
                            logging.exception(e)
                            self._send_framed(sk, json.dumps({"code": 1, "msg": str(e)}))
                    elif msg_type == "mysql":
                        try:
                            data = data_json["data"]["data"]
                            db = data["db"]
                            cmd = data["cmd"]
                            args = data.get("args")
                            mysql = mysql_data.Mysqldb(get_mysql_config(db))
                            # NOTE(security): ``cmd`` is client-supplied and resolved
                            # with getattr, so any attribute of the Mysqldb object can
                            # be invoked remotely.  Consider an explicit allow-list.
                            method = getattr(mysql, cmd)
                            call_args = []
                            if args:
                                if isinstance(args, (tuple, list)):
                                    call_args = list(args)
                                else:
                                    call_args.append(args)
                            result = method(*call_args)
                            # default=str: values json cannot encode are sent as strings.
                            self._send_framed(sk, json.dumps({"code": 0, "data": result}, default=str))
                        except Exception as e:
                            logging.exception(e)
                            self._send_framed(sk, json.dumps({"code": 1, "msg": str(e)}))
                    elif msg_type == "juejin":
                        # JueJin HTTP API request.
                        try:
                            data = data_json["data"]["data"]
                            path_ = data["path"]
                            params = data.get("params")
                            result = JueJinHttpApi.request(path_, params)
                            self._send_framed(sk, json.dumps({"code": 0, "data": result}))
                        except Exception as e:
                            logging.exception(e)
                            self._send_framed(sk, json.dumps({"code": 1, "msg": str(e)}))
                    elif msg_type == "kpl":
                        # KPL API request.
                        try:
                            data = data_json["data"]["data"]
                            url = data["url"]
                            data_ = data.get("data")
                            result = kpl_api_util.request(url, data_)
                            self._send_framed(sk, json.dumps({"code": 0, "data": result}))
                        except Exception as e:
                            logging.exception(e)
                            self._send_framed(sk, json.dumps({"code": 1, "msg": str(e)}))
                    elif msg_type == "kp_msg":
                        # Message for the market-watching client.
                        msg = data_json["data"]["data"]["msg"]
                        kp_client_msg_manager.add_msg(msg)
                        self._send_framed(sk, json.dumps({"code": 0, "data": {}}))
                    elif msg_type == "push_msg":
                        data = data_json["data"]["data"]
                        _type = data["type"]
                        data = data.get("data")
                        logger_debug.info(f"推送消息:{data_json}")
                        push_msg_manager.push_msg(_type, data)
                        self._send_framed(sk, json.dumps({"code": 0, "data": {}}))
                    elif msg_type == 'l1_data':
                        # One-shot request: no reply is sent and the connection ends.
                        datas = data_json["data"]
                        L1DataManager().add_datas(datas)
                        break
                    elif msg_type == 'get_l1_target_codes':
                        # Return the target codes, then end the connection.
                        codes = L1DataManager().get_target_codes()
                        self._send_framed(sk, json.dumps({"code": 0, "data": list(codes)}))
                        break
                    elif msg_type == 'get_third_blocks':
                        # Return third-party blocks, then end the connection.
                        data = data_json["data"]["data"]
                        source = data["source"]
                        code = data["code"]
                        self._send_framed(sk, self._third_blocks_response(source, code))
                        break
                    elif msg_type == 'low_suction':
                        # TODO low-suction channel
                        datas = data_json["data"]
                except Exception as e:
                    log.logger_tuoguan_request_debug.exception(e)
                finally:
                    elapsed = time.time() - started_at
                    if elapsed > 2:
                        log.logger_tuoguan_request_debug.info(
                            f"耗时:{int(elapsed)}s 数据:{data_json}")
            except Exception as e:
                logging.exception(e)
                break

    def finish(self):
        """Delegate connection teardown to the base class."""
        super().finish()
|
|
|
def clear_invalid_client():
    """Periodically remove invalid clients; never returns.

    Calls ``socket_manager.ClientSocketManager.del_invalid_clients()`` and then
    sleeps 2 seconds, forever.  A failing sweep is logged and does not stop the
    loop.  Intended to run in a background thread.
    """
    while True:
        try:
            socket_manager.ClientSocketManager.del_invalid_clients()
        except Exception as e:
            # Best effort: keep sweeping, but leave a trace of what went wrong
            # instead of silently discarding the error.
            logging.exception(e)
        finally:
            time.sleep(2)
|
|
|
def __recv_pipe_l1(pipe_trade, pipe_l1):
    """Consume JSON messages arriving on *pipe_l1* until the pipe is closed.

    Nothing is done unless both pipes are given.  Every non-empty value
    received is parsed as JSON and printed; values that are not valid JSON are
    logged and skipped.

    :param pipe_trade: trade pipe; only checked for being not ``None``.
    :param pipe_l1: object with a blocking ``recv()`` method that yields the
        JSON text of each message.
    :return: ``None`` once the pipe is closed (``recv`` raises ``EOFError`` or
        ``OSError``), or immediately when a pipe is missing.
    """
    if pipe_trade is None or pipe_l1 is None:
        return
    while True:
        try:
            val = pipe_l1.recv()
        except (EOFError, OSError):
            # The other end is gone.  Retrying would fail again immediately
            # and turn this loop into a busy spin, so stop instead.
            return
        except Exception as e:
            # A single message could not be received; skip it.
            logging.exception(e)
            continue
        if not val:
            continue
        try:
            val = json.loads(val)
        except (TypeError, ValueError) as e:
            logging.exception(e)
            continue
        print("收到来自L1的数据:", val)
        # TODO process the data
|
|
|
def run(port=constant.MIDDLE_SERVER_PORT):
    """Start the middle server and serve requests forever.

    A daemon thread running ``clear_invalid_client`` is started first, then a
    ``MyThreadingTCPServer`` using ``MyBaseRequestHandle`` is bound to all
    interfaces on *port*.  This call blocks in ``serve_forever()``.

    :param port: TCP port to listen on.
    """
    print("create MiddleServer")
    cleaner = threading.Thread(target=clear_invalid_client, daemon=True)
    cleaner.start()

    bind_address = ("0.0.0.0", port)
    print("MiddleServer is at: http://%s:%d/" % bind_address)
    # Note: the handler argument is MyBaseRequestHandle.
    server = MyThreadingTCPServer(bind_address, MyBaseRequestHandle)
    server.serve_forever()
|
|
|
# Running this file as a script intentionally does nothing; the server is
# started by calling run().
if __name__ == "__main__":
    pass
|