import hashlib
|
import json
|
import logging
|
import queue
|
import random
|
import socket
|
import socketserver
|
import threading
|
import time
|
|
import constant
|
import socket_manager
|
from db import mysql_data
|
from db.redis_manager import RedisUtils, RedisManager
|
from log import logger_debug, logger_request_debug
|
from utils import socket_util, kpl_api_util, hosting_api_util, kp_client_msg_manager, global_data_cache_util, tool
|
from utils.juejin_util import JueJinHttpApi
|
|
# Module-level FIFO queue for trade-data requests; unbounded (maxsize=0 means no limit).
# NOTE(review): not referenced by the code visible here — confirm who produces/consumes it.
trade_data_request_queue = queue.Queue(maxsize=0)
|
|
|
class MyTCPServer(socketserver.TCPServer):
    """TCPServer that always binds and starts listening as soon as it is constructed."""

    def __init__(self, server_address, RequestHandlerClass):
        # bind_and_activate=True: server_bind() and server_activate() run inside the constructor.
        super().__init__(server_address, RequestHandlerClass, bind_and_activate=True)
|
|
|
class MyThreadingTCPServer(socketserver.ThreadingMixIn, MyTCPServer):
    """Thread-per-connection variant of MyTCPServer.

    This re-derives socketserver.ThreadingTCPServer on top of MyTCPServer so
    the asynchronous (threaded) form keeps MyTCPServer's construction
    behaviour. ThreadingMixIn must be listed first so that its
    process_request (which spawns a thread per connection) overrides the
    synchronous one inherited through MyTCPServer.
    """
|
|
|
class MyBaseRequestHandle(socketserver.BaseRequestHandler):
    """Per-connection handler for the middle server's JSON request protocol.

    Each incoming message is read with ``getRecvData``. A message is either
    framed (``##`` plus an 8-digit byte length) or a legacy unframed read.
    It is parsed as JSON and dispatched on its ``"type"`` field.

    Most request types receive exactly one framed reply built with
    ``socket_util.load_header``. A ``register`` request instead turns the
    connection into a long-lived client connection whose loop only consumes
    heartbeats. Unknown types are ignored without a reply.

    NOTE(review): ``__is_sign_right`` is not called anywhere in this class,
    so requests are not authenticated. Meanwhile the ``redis`` and ``mysql``
    types resolve a client-supplied method name with ``getattr``. Anyone who
    can reach this port can therefore invoke arbitrary ``RedisUtils`` /
    ``Mysqldb`` attributes. Restrict network access or whitelist commands.
    """

    # One-time class initialisation flag (name-mangled to _MyBaseRequestHandle__inited).
    __inited = False

    def setup(self):
        """Ensure the one-time class initialisation has run before handling a connection."""
        self.__init()

    @classmethod
    def __init(cls):
        """Initialise class-level state once; later calls return True immediately."""
        if cls.__inited:
            return True
        cls.__inited = True
        cls.__req_socket_dict = {}

    def __is_sign_right(self, data_json):
        """Verify the MD5 request signature stored in ``data_json["sign"]``.

        The expected signature is the MD5 hex digest of every remaining
        ``key=value`` pair, sorted as strings, joined with ``&`` and followed
        by a fixed salt. ``"sign"`` is popped from ``data_json`` as a side
        effect.

        Raises:
            KeyError: if ``data_json`` has no ``"sign"`` entry.
            Exception: if the signature does not match.
        """
        import hmac

        sign = data_json.pop("sign")
        pairs = sorted(f"{k}={v}" for k, v in data_json.items())
        salted = "&".join(pairs) + "JiaBei@!*."
        md5 = hashlib.md5(salted.encode(encoding='utf-8')).hexdigest()
        # Constant-time comparison, so timing does not reveal how many
        # leading characters of a forged signature were correct.
        if not isinstance(sign, str) or not hmac.compare_digest(
                md5.encode('utf-8'), sign.encode('utf-8')):
            raise Exception("签名出错")

    @classmethod
    def getRecvData(cls, skk):
        """Read one message from socket ``skk``.

        Two wire formats are accepted:

        * framed: ``##`` + 8 ASCII decimal digits giving the payload length
          in bytes, followed by exactly that many bytes of UTF-8 payload;
        * legacy (anything else): the first 10 bytes plus whatever a single
          ``recv(1 MiB)`` returns are treated as the whole message.

        Returns:
            ``(text, header)``: the decoded payload and the raw header bytes.
            For the legacy format, ``header`` is the first bytes read.
            ``("", b"")`` means the peer closed before sending anything.

        Raises:
            ConnectionResetError: the peer closed in the middle of a frame.
            ValueError: the length digits are malformed, or the payload is
                not valid UTF-8 (UnicodeDecodeError).
        """
        header_size = 10
        header = skk.recv(header_size)
        if not header:
            return "", header
        if header.startswith(b"##"):
            # recv() may return a short read; finish the fixed-size header first.
            while len(header) < header_size:
                more = skk.recv(header_size - len(header))
                if not more:
                    raise ConnectionResetError("connection closed while reading frame header")
                header += more
            remaining = int(header[2:header_size])
            chunks = []
            while remaining > 0:
                # Never read past this frame: the next message may already be buffered.
                chunk = skk.recv(min(10240, remaining))
                if not chunk:
                    # Previously this looped forever on b"" once the peer went away.
                    raise ConnectionResetError("connection closed in the middle of a frame")
                chunks.append(chunk)
                remaining -= len(chunk)
            # Decode once, so multibyte UTF-8 characters split across reads survive.
            return b"".join(chunks).decode('utf-8'), header
        # Legacy unframed message: decode the header bytes together with the rest,
        # because the 10-byte cut may fall inside a multibyte character.
        rest = skk.recv(1024 * 1024)
        return (header + rest).decode('utf-8'), header

    def handle(self):
        """Serve requests on this connection until the peer disconnects.

        Requests are processed one at a time. A message that is not valid JSON
        gets a ``{"code": 100}`` reply and the loop continues. Any other
        unexpected error is logged and ends the connection.
        """
        super().handle()
        sk: socket.socket = self.request
        while True:
            try:
                data, _ = self.getRecvData(sk)
                if not data:
                    # Empty read: the peer closed the connection.
                    break
                try:
                    data_json = json.loads(data)
                except json.decoder.JSONDecodeError:
                    self._send_framed(sk, json.dumps({"code": 100, "msg": "JSON解析失败"}))
                    continue
                # Random tag used only to pair the start/end log lines of this request.
                thread_id = random.randint(0, 1000000)
                logger_request_debug.info(f"middle_server 请求开始({thread_id}):{data_json.get('type')}")
                try:
                    if self._dispatch(sk, data_json):
                        # The register handler is finished with (or has closed) the socket;
                        # reading it again would only raise on a dead socket.
                        return
                finally:
                    logger_request_debug.info(f"middle_server 请求结束({thread_id}):{data_json.get('type')}")
            except Exception as e:
                logging.exception(e)
                break

    def _dispatch(self, sk, data_json):
        """Route one parsed request to its handler.

        Returns True when the connection must not be read from again.
        Raises KeyError if the request has no ``"type"``.
        """
        msg_type = data_json["type"]
        handlers = {
            "register": self._handle_register,
            "response": self._handle_response,
            "l2_subscript_codes": self._handle_l2_subscript_codes,
            "redis": self._handle_redis,
            "mysql": self._handle_mysql,
            "juejin": self._handle_juejin,
            "kpl": self._handle_kpl,
            "kp_msg": self._handle_kp_msg,
        }
        # Non-string types (possibly unhashable) never matched any branch before.
        handler = handlers.get(msg_type) if isinstance(msg_type, str) else None
        if handler is None:
            return False
        return bool(handler(sk, data_json))

    @staticmethod
    def _send_framed(sk, text):
        """Send ``text`` as UTF-8 wrapped in the protocol header from socket_util.load_header."""
        sk.sendall(socket_util.load_header(text.encode(encoding='utf-8')))

    def _send_error(self, sk, error):
        """Send the standard ``{"code": 1, "msg": ...}`` error reply."""
        self._send_framed(sk, json.dumps({"code": 1, "msg": str(error)}))

    def _handle_register(self, sk, data_json):
        """Register this socket as a client connection and consume its heartbeats.

        Returns True once the socket is finished: the peer closed it (the
        socket is then closed here) or the connection was reset (the client
        is then removed). Returns False if the heartbeat loop stopped on any
        other error; the caller then keeps reading the socket as before.
        """
        client_type = data_json["data"]["client_type"]
        rid = data_json["rid"]
        socket_manager.ClientSocketManager.add_client(client_type, rid, sk)
        # Note: this acknowledgement is sent without the protocol header.
        sk.sendall(json.dumps({"type": "register"}).encode(encoding='utf-8'))
        try:
            while True:
                result, header = self.getRecvData(sk)
                if not result:
                    # NOTE(review): the client is not unregistered on a clean close;
                    # presumably del_invalid_clients() reaps it later — confirm.
                    sk.close()
                    return True
                try:
                    resultJSON = json.loads(result)
                    if resultJSON["type"] == 'heart':
                        # Record that this client is still alive.
                        socket_manager.ClientSocketManager.heart(resultJSON['client_id'])
                except json.decoder.JSONDecodeError:
                    print("JSON解析出错", result, header)
                time.sleep(1)
        except ConnectionResetError:
            socket_manager.ClientSocketManager.del_client(rid)
            return True
        except Exception as e:
            logging.exception(e)
        return False

    def _handle_response(self, sk, data_json):
        """Store a client's response to an earlier request; always acknowledged with code 0."""
        try:
            client_id = data_json["client_id"]
            hosting_api_util.set_response(client_id, data_json["request_id"], data_json['data'])
        finally:
            self._send_framed(sk, json.dumps({"code": 0}))

    def _handle_l2_subscript_codes(self, sk, data_json):
        """Record the subscribed L2 codes in global_data_cache_util; always acknowledged with code 0."""
        try:
            datas = data_json["data"]["data"]
            print("l2_subscript_codes", data_json)
            global_data_cache_util.huaxin_subscript_codes = datas
            global_data_cache_util.huaxin_subscript_codes_update_time = tool.get_now_time_str()
        finally:
            self._send_framed(sk, json.dumps({"code": 0}))

    @staticmethod
    def _run_redis_cmd(cmd_data):
        """Execute one ``{"db", "cmd", "key", "args"?}`` command through RedisUtils.

        Calls ``RedisUtils.<cmd>(redis, key, *args)``. A list/tuple ``args``
        is spread and any other non-None value is passed as one argument.
        Set results are converted to lists so they are JSON-serialisable.
        """
        db = cmd_data["db"]
        cmd = cmd_data["cmd"]
        key = cmd_data["key"]
        args = cmd_data.get("args")
        redis = RedisManager(db).getRedis()
        # SECURITY(review): `cmd` comes from the network and is not whitelisted.
        method = getattr(RedisUtils, cmd)
        call_args = [redis, key]
        if args is not None:
            if isinstance(args, (tuple, list)):
                call_args.extend(args)
            else:
                call_args.append(args)
        result = method(*call_args)
        if isinstance(result, set):
            result = list(result)
        return result

    def _handle_redis(self, sk, data_json):
        """Handle the ``queue_size``, ``cmd`` and ``cmds`` Redis sub-requests."""
        try:
            data = data_json["data"]
            ctype = data["ctype"]
            # An unknown ctype is answered with an empty payload.
            result_str = ''
            if ctype == "queue_size":
                # TODO: set the queue size
                result_str = json.dumps({"code": 0})
            elif ctype == "cmd":
                result = self._run_redis_cmd(data["data"])
                result_str = json.dumps({"code": 0, "data": result})
            elif ctype == "cmds":
                result_list = [self._run_redis_cmd(d) for d in data["data"]]
                result_str = json.dumps({"code": 0, "data": result_list})
            self._send_framed(sk, result_str)
        except Exception as e:
            logger_debug.exception(e)
            logger_debug.info(f"Redis操作出错:data_json:{data_json}")
            logging.exception(e)
            self._send_error(sk, e)

    def _handle_mysql(self, sk, data_json):
        """Call ``Mysqldb().<cmd>(*args)`` and reply with the result (non-JSON values via str())."""
        try:
            data = data_json["data"]["data"]
            # Required by the request format (KeyError if absent) but not used here.
            db = data["db"]
            cmd = data["cmd"]
            args = data.get("args")
            mysql = mysql_data.Mysqldb()
            # SECURITY(review): `cmd` comes from the network and is not whitelisted.
            method = getattr(mysql, cmd)
            call_args = []
            # Unlike the Redis path, any falsy args means "no arguments".
            if args:
                if isinstance(args, (tuple, list)):
                    call_args = list(args)
                else:
                    call_args.append(args)
            result = method(*call_args)
            self._send_framed(sk, json.dumps({"code": 0, "data": result}, default=str))
        except Exception as e:
            logging.exception(e)
            self._send_error(sk, e)

    def _handle_juejin(self, sk, data_json):
        """Proxy a JueJin HTTP API request (``path`` + optional ``params``)."""
        try:
            data = data_json["data"]["data"]
            path_ = data["path"]
            params = data.get("params")
            result = JueJinHttpApi.request(path_, params)
            self._send_framed(sk, json.dumps({"code": 0, "data": result}))
        except Exception as e:
            logging.exception(e)
            self._send_error(sk, e)

    def _handle_kpl(self, sk, data_json):
        """Proxy a KPL API request (``url`` + optional ``data``)."""
        try:
            data = data_json["data"]["data"]
            url = data["url"]
            data_ = data.get("data")
            result = kpl_api_util.request(url, data_)
            self._send_framed(sk, json.dumps({"code": 0, "data": result}))
        except Exception as e:
            logging.exception(e)
            self._send_error(sk, e)

    def _handle_kp_msg(self, sk, data_json):
        """Queue a market-watch message via kp_client_msg_manager.add_msg.

        Failures now get the same ``code: 1`` error reply as the other proxy
        handlers, instead of silently dropping the connection.
        """
        try:
            msg = data_json["data"]["data"]["msg"]
            kp_client_msg_manager.add_msg(msg)
            self._send_framed(sk, json.dumps({"code": 0, "data": {}}))
        except Exception as e:
            logging.exception(e)
            self._send_error(sk, e)

    def finish(self):
        super().finish()
|
|
|
def clear_invalid_client():
    """Sweep invalid clients out of ClientSocketManager forever, every 2 seconds.

    This is a best-effort background loop. A failing sweep is logged and
    retried on the next pass. The previous bare ``except: pass`` hid every
    error and also swallowed SystemExit/KeyboardInterrupt.
    """
    while True:
        try:
            socket_manager.ClientSocketManager.del_invalid_clients()
        except Exception:
            logging.exception("clear_invalid_client: del_invalid_clients failed")
        # Sleep outside a `finally` so a BaseException (e.g. SystemExit) exits immediately.
        time.sleep(2)
|
|
|
def __recv_pipe_l1(pipe_trade, pipe_l1):
    """Receive JSON messages from the L1 pipe and print them.

    Does nothing unless both ``pipe_trade`` and ``pipe_l1`` are given.
    Otherwise it reads ``pipe_l1`` until the other end is closed, then
    returns. Previously a closed pipe raised EOFError on every ``recv()``.
    That error was swallowed by a bare ``except``, so the function spun
    forever at full CPU.

    Args:
        pipe_trade: currently only checked for None.
        pipe_l1: connection-like object whose ``recv()`` returns a JSON
            string/bytes (assumed multiprocessing.Connection — confirm).
    """
    if pipe_trade is None or pipe_l1 is None:
        return
    while True:
        try:
            val = pipe_l1.recv()
        except (EOFError, OSError):
            # The other end closed the pipe (or the handle was closed): stop.
            break
        except Exception:
            logging.exception("__recv_pipe_l1: failed to receive from L1 pipe")
            continue
        if not val:
            continue
        try:
            val = json.loads(val)
        except (TypeError, ValueError):
            logging.exception("__recv_pipe_l1: invalid JSON from L1 pipe")
            continue
        print("收到来自L1的数据:", val)
        # TODO: process the data
|
|
|
def run():
    """Start the middle server and block serving requests forever.

    Also starts the ``clear_invalid_client`` sweeper on a daemon thread. The
    server listens on all interfaces at ``constant.MIDDLE_SERVER_PORT``, with
    one handler thread per connection.
    """
    print("create MiddleServer")
    threading.Thread(target=clear_invalid_client, daemon=True).start()

    laddr = "0.0.0.0", constant.MIDDLE_SERVER_PORT
    print("MiddleServer is at: http://%s:%d/" % laddr)
    # The context manager calls server_close() if serve_forever() exits
    # (e.g. KeyboardInterrupt), so the listening socket is released promptly.
    with MyThreadingTCPServer(laddr, MyBaseRequestHandle) as tcpserver:
        tcpserver.serve_forever()
|
|
|
if __name__ == "__main__":
    # Executing this module directly is a no-op; call run() to start the server.
    ...
|