import hashlib
import hmac
import json
import logging
import random
import socket
import socketserver
import threading
import time

import trade_server_processor
|
|
|
class ClientSocketManager:
    """Registry of client sockets registered by the request handler.

    ``trade`` clients form a pool stored as a list of ``(rid, sk)`` pairs; every
    other client type holds a single ``(rid, sk)`` pair, and a new registration
    replaces the previous one. Each ``rid`` owns a ``threading.Lock``:
    ``acquire_client`` hands out a pair only if that lock can be taken without
    blocking, and ``release_client`` gives it back.
    """

    # Client types
    CLIENT_TYPE_TRADE = "trade"
    CLIENT_TYPE_DELEGATE_LIST = "delegate_list"
    CLIENT_TYPE_DEAL_LIST = "deal_list"
    CLIENT_TYPE_POSITION_LIST = "position_list"
    CLIENT_TYPE_MONEY = "money"
    CLIENT_TYPE_DEAL = "deal"
    CLIENT_TYPE_CMD_L2 = "l2_cmd"

    # client_type -> list of (rid, sk) for "trade", or a single (rid, sk) otherwise
    socket_client_dict = {}
    # rid -> threading.Lock; held while that client is checked out
    socket_client_lock_dict = {}
    # Guards both dicts above: handler threads register/unregister concurrently
    # with threads that acquire clients.
    _registry_lock = threading.RLock()

    @classmethod
    def add_client(cls, _type, rid, sk):
        """Register socket ``sk`` with id ``rid`` under client type ``_type``.

        A fresh, unlocked lock is created for ``rid``. Re-registering a trade
        ``rid`` replaces its previous entry instead of appending a duplicate
        (which would otherwise hand out the stale socket first).
        """
        with cls._registry_lock:
            if _type == cls.CLIENT_TYPE_TRADE:
                # Trade list: a pool of interchangeable clients.
                clients = cls.socket_client_dict.setdefault(_type, [])
                clients[:] = [d for d in clients if d[0] != rid]
                clients.append((rid, sk))
            else:
                cls.socket_client_dict[_type] = (rid, sk)
            cls.socket_client_lock_dict[rid] = threading.Lock()
            print(cls.socket_client_dict)

    @classmethod
    def acquire_client(cls, _type):
        """Lock and return a free ``(rid, sk)`` of type ``_type``, or None.

        Never blocks: clients whose lock is already held, or whose lock has been
        removed by ``del_client``, are skipped.
        """
        with cls._registry_lock:
            entry = cls.socket_client_dict.get(_type)
            if entry is None:
                return None
            candidates = entry if _type == cls.CLIENT_TYPE_TRADE else [entry]
            for d in candidates:
                # .get(): the lock may have been dropped by del_client.
                lock = cls.socket_client_lock_dict.get(d[0])
                if lock is not None and lock.acquire(blocking=False):
                    return d
        return None

    @classmethod
    def release_client(cls, rid):
        """Release the lock taken by ``acquire_client`` for ``rid``.

        Unknown rids are ignored; releasing a lock that is not held raises
        RuntimeError (from ``threading.Lock.release``).
        """
        lock = cls.socket_client_lock_dict.get(rid)
        if lock is not None:
            # Release the lock
            lock.release()

    @classmethod
    def del_client(cls, rid):
        """Forget every registration of ``rid`` together with its lock."""
        with cls._registry_lock:
            # Remove the thread lock
            cls.socket_client_lock_dict.pop(rid, None)
            # Remove the socket entries; iterate over a snapshot because
            # single-slot types are popped from the dict inside the loop.
            for t, entry in list(cls.socket_client_dict.items()):
                if isinstance(entry, list):
                    entry[:] = [d for d in entry if d[0] != rid]
                elif isinstance(entry, tuple) and entry[0] == rid:
                    cls.socket_client_dict.pop(t)
|
|
|
class MyTCPServer(socketserver.TCPServer):
    """TCPServer that binds and starts listening as soon as it is constructed."""

    def __init__(self, server_address, RequestHandlerClass):
        super().__init__(
            server_address,
            RequestHandlerClass,
            bind_and_activate=True,
        )
|
|
|
# For the threaded (asynchronous) form, the ThreadingMixIn combination must be
# declared again on top of MyTCPServer: the stdlib ThreadingTCPServer is built
# on the plain TCPServer, not on MyTCPServer.
class MyThreadingTCPServer(socketserver.ThreadingMixIn, MyTCPServer):
    """MyTCPServer that handles each connection in its own thread."""
|
|
|
class MyBaseRequestHandle(socketserver.BaseRequestHandler):
    """Per-connection handler for the trade server.

    Each ``recv`` chunk is decoded as one UTF-8 JSON object. A message with
    ``"type": "register"`` stores the socket in ClientSocketManager under
    ``data["client_type"]`` / ``rid``. Any other message has its ``"data"``
    passed to ``trade_server_processor.process`` and is answered with
    ``{"code": 0}``.
    """

    # Set by __init() so the class-level state is created only once.
    __inited = False

    def setup(self):
        self.__init()

    @classmethod
    def __init(cls):
        """Create the class-level ``__req_socket_dict`` on first use.

        NOTE(review): the dict is not read anywhere in the visible code.
        """
        if cls.__inited:
            return True
        cls.__inited = True
        cls.__req_socket_dict = {}

    def __is_sign_right(self, data_json):
        """Verify the MD5 ``"sign"`` field of a request dict.

        The expected signature is the MD5 hex digest of every other field
        formatted as ``key=value``, sorted as strings, joined with ``&`` and
        followed by a fixed salt. ``data_json`` is not modified.

        :raises KeyError: if ``data_json`` has no ``"sign"`` key.
        :raises Exception: ("签名出错") if the signature does not match.
        """
        sign = data_json["sign"]
        # Sign every field except the signature itself.
        list_str = sorted(f"{k}={v}" for k, v in data_json.items() if k != "sign")
        # NOTE(review): hard-coded salt + MD5 is weak; consider a configured
        # secret with HMAC-SHA256 if the client protocol can change.
        payload = "&".join(list_str) + "JiaBei@!*."
        md5 = hashlib.md5(payload.encode(encoding='utf-8')).hexdigest()
        # Constant-time comparison so response timing does not reveal how much
        # of a forged signature was correct.
        if not hmac.compare_digest(md5.encode("utf-8"), str(sign).encode("utf-8")):
            raise Exception("签名出错")

    def handle(self):
        """Serve the connection until it registers, closes or the socket fails.

        Messages that cannot be decoded or processed are logged and skipped so
        the connection keeps being served.
        """
        super().handle()
        sk: socket.socket = self.request
        while True:
            try:
                data = sk.recv(1024 * 100)
            except OSError:
                # Connection reset/aborted: the socket is unusable from now on.
                break
            if not data:
                # Peer closed the connection; recv() would return b"" forever
                # and the loop would spin.
                break
            try:
                data_str = str(data, encoding="utf-8")
                print("收到数据------", data_str)
                # NOTE(review): assumes each recv() yields exactly one complete
                # JSON message; TCP itself does not guarantee that framing.
                data_json = json.loads(data_str)
                if data_json["type"] == 'register':
                    client_type = data_json["data"]["client_type"]
                    rid = data_json["rid"]
                    ClientSocketManager.add_client(client_type, rid, sk)
                    # sendall(): send() may transmit only part of the buffer.
                    sk.sendall(json.dumps({"type": "register"}).encode(encoding='utf-8'))
                    sk.recv(1024 * 100)
                    # NOTE(review): after handle() returns, socketserver shuts
                    # down and closes this request socket, including the copy
                    # just stored in ClientSocketManager; confirm intended.
                    break
                trade_server_processor.process(data_json["data"])
                sk.sendall(json.dumps({"code": 0}).encode(encoding='utf-8'))
            except Exception:
                # Best effort: report the bad message and keep serving. A broken
                # socket is detected by the next recv() above.
                logging.getLogger(__name__).exception(
                    "error handling message from %s", self.client_address)

    def finish(self):
        super().finish()
|
|
|
def run(host="", port=10008):
    """Start the threaded trade server and serve until interrupted.

    :param host: interface to bind; "" (the default) listens on all interfaces.
    :param port: TCP port to listen on (default 10008).
    """
    laddr = (host, port)
    # Each connection is handled by MyBaseRequestHandle in its own thread.
    # The with-block closes the listening socket when serve_forever() exits
    # (e.g. on KeyboardInterrupt).
    with MyThreadingTCPServer(laddr, MyBaseRequestHandle) as tcpserver:
        tcpserver.serve_forever()
|
|
|
def test1():
    """Demo: grab a free trade client, hold it for 0-3 s, then release it."""
    client = ClientSocketManager.acquire_client(ClientSocketManager.CLIENT_TYPE_TRADE)
    print("test1", client)
    hold_seconds = random.randint(0, 3)
    time.sleep(hold_seconds)
    if not client:
        return
    rid = client[0]
    ClientSocketManager.release_client(rid)
|
|
|
def test2():
    """Demo: after a random 0-3 s delay, try to grab a trade client and print it.

    NOTE(review): a successfully acquired client is never released here, so it
    stays checked out afterwards; confirm this is intentional.
    """
    delay = random.randint(0, 3)
    time.sleep(delay)
    acquired = ClientSocketManager.acquire_client(ClientSocketManager.CLIENT_TYPE_TRADE)
    print("test2", acquired)
|
|
|
if __name__ == "__main__":
    # run() blocks in serve_forever(), and nothing in this module calls
    # shutdown(), so statements placed after it would never execute.
    run()

    # Lock-contention demo for ClientSocketManager: comment out run() above and
    # uncomment the lines below.
    # for rid in ("1", "2", "3"):
    #     ClientSocketManager.add_client(ClientSocketManager.CLIENT_TYPE_TRADE, rid, None)
    # for _ in range(3):
    #     threading.Thread(target=test1, daemon=True).start()
    # for _ in range(3):
    #     threading.Thread(target=test2, daemon=True).start()
    # input()  # keep the main thread alive while the daemon threads run
|