"""
Third-party block (sector) management.

Stores, per stock code, the block names reported by several third-party
sources in the ``code_third_blocks`` MySQL table and mirrors them in
class-level caches. Also manages the ``block_map`` (block-name mapping) and
``invalid_block`` tables that are used to normalize those names.
"""
import logging
from itertools import combinations

from db.mysql_data_delegate import Mysqldb
from utils import middle_api_protocol
from utils.kpl_data_db_util import KPLLimitUpDataUtil
from utils.ths_industry_util import ThsCodeIndustryManager

SOURCE_TYPE_KPL = 1  # KaiPanLa
SOURCE_TYPE_TDX = 2  # TongDaXin
SOURCE_TYPE_THS = 3  # TongHuaShun
SOURCE_TYPE_EASTMONEY = 4  # East Money
SOURCE_TYPE_KPL_RECORD = 5  # KaiPanLa historical limit-up records

# Separator used when several block names are stored in one DB column.
_BLOCK_SEPARATOR = "、"

_logger = logging.getLogger(__name__)


def _split_blocks(text):
    """
    Split a separator-joined block string into a set of non-empty names.

    ``None`` and ``""`` yield an empty set, so an empty DB column never turns
    into a phantom ``""`` block.
    """
    if not text:
        return set()
    return {b for b in text.split(_BLOCK_SEPARATOR) if b}


class CodeThirdBlocksManager:
    """
    Per-code third-party block manager (singleton).

    The first instantiation loads every ``code_third_blocks`` row plus the
    latest KPL limit-up block records into two class-level caches with the
    shape ``{code: {source_type: {block names}}}``:

    * the "origin" cache holds the raw names;
    * the filtered cache holds the same names passed through
      ``BlockMapManager().filter_blocks``.
    """
    __instance = None
    __mysql = Mysqldb()
    # code -> {source_type -> filtered block names},
    # e.g. {code: {1: {"b1", "b2"}, 2: {"c1", "c2"}}}
    __code_source_blocks_dict = {}
    # Same shape as above, holding the raw (unfiltered) block names.
    __code_source_blocks_dict_origin = {}
    __ths_industry = ThsCodeIndustryManager()

    def __new__(cls, *args, **kwargs):
        if not cls.__instance:
            # object.__new__ must not receive extra arguments.
            cls.__instance = super().__new__(cls)
            cls.__load_data()
        return cls.__instance

    @classmethod
    def __store_blocks(cls, code, source_type, raw_blocks, block_map_manager=None):
        """
        Cache ``raw_blocks`` and its filtered form for ``code``/``source_type``.

        @param raw_blocks: set of raw block names (stored as-is in the origin cache)
        @param block_map_manager: optional BlockMapManager to reuse in loops
        """
        if block_map_manager is None:
            block_map_manager = BlockMapManager()
        cls.__code_source_blocks_dict_origin.setdefault(code, {})[source_type] = raw_blocks
        cls.__code_source_blocks_dict.setdefault(code, {})[source_type] = \
            block_map_manager.filter_blocks(raw_blocks)

    @classmethod
    def __load_data(cls):
        """Rebuild both caches from ``code_third_blocks`` and the KPL limit-up history."""
        results = cls.__mysql.select_all("select _code, _source, _blocks from code_third_blocks")
        # Clear both caches so a reload cannot keep stale raw entries.
        cls.__code_source_blocks_dict.clear()
        cls.__code_source_blocks_dict_origin.clear()
        block_map_manager = BlockMapManager()

        for result in results:
            code, source = result[0], result[1]
            blocks = _split_blocks(result[2])
            if source == SOURCE_TYPE_THS:
                # For THS rows also include the code's industry (second-level classification).
                industry = cls.__ths_industry.get_industry(code)
                if industry:
                    blocks.add(industry)
            cls.__store_blocks(code, source, blocks, block_map_manager)

        # Latest KPL limit-up reasons: per code, column 2 is added as-is and
        # column 3 (when set) is split on the separator.
        kpl_code_blocks = {}
        for r in KPLLimitUpDataUtil.get_latest_block_infos():
            code_set = kpl_code_blocks.setdefault(r[0], set())
            code_set.add(r[2])
            if r[3]:
                code_set |= _split_blocks(r[3])
        for code, blocks in kpl_code_blocks.items():
            cls.__store_blocks(code, SOURCE_TYPE_KPL_RECORD, blocks, block_map_manager)

    def get_source_blocks(self, code):
        """
        Return the filtered ``{source_type: blocks}`` mapping for a code.

        @param code: stock code
        @return: the cached dict, or None if the code is unknown
        """
        return self.__code_source_blocks_dict.get(code)

    def get_source_blocks_origin(self, code):
        """
        Return the raw (unfiltered) ``{source_type: blocks}`` mapping for a code.

        @param code: stock code
        @return: the cached dict, or None if the code is unknown
        """
        return self.__code_source_blocks_dict_origin.get(code)

    def get_intersection_blocks_info(self, code, blocks, same_count=2):
        """
        Find blocks that appear in at least ``same_count`` candidate block sets.

        The candidates are ``filter_blocks(blocks)`` (if non-empty) followed by
        every non-empty cached filtered set for ``code``. For every combination
        of ``same_count`` candidates their intersection is added to the result.

        @param code: stock code
        @param blocks: raw block names supplied by the caller
        @param same_count: number of sets a block must appear in together
        @return: ``(shared_blocks, candidate_sets)``; ``shared_blocks`` is empty
                 when there are fewer than ``same_count`` candidate sets
        """
        candidate_sets = []
        own_blocks = BlockMapManager().filter_blocks(blocks)
        if own_blocks:
            candidate_sets.append(own_blocks)
        source_blocks = self.__code_source_blocks_dict.get(code)
        if source_blocks:
            candidate_sets.extend(bs for bs in source_blocks.values() if bs)
        if len(candidate_sets) < same_count:
            return set(), candidate_sets

        shared = set()
        for group in combinations(candidate_sets, same_count):
            if group:  # only empty when same_count == 0
                shared |= set(group[0]).intersection(*group[1:])
        return shared, candidate_sets

    def set_blocks(self, code, blocks, source_type):
        """
        Persist the blocks of ``code`` from ``source_type`` and refresh the caches.

        The caches are only overwritten for a non-empty ``blocks``; the DB row
        is inserted/updated in every case.

        @param code: stock code
        @param blocks: iterable of block names from that source
        @param source_type: one of the ``SOURCE_TYPE_*`` constants
        """
        # Materialize once: the names are used for both caches and the DB value.
        block_list = list(blocks)
        self.__code_source_blocks_dict.setdefault(code, {})
        self.__code_source_blocks_dict_origin.setdefault(code, {})
        if block_list:
            self.__store_blocks(code, source_type, set(block_list))

        _id = f"{code}_{source_type}"
        blocks_str = _BLOCK_SEPARATOR.join(block_list)
        # NOTE(review): values are interpolated into SQL; a quote in a block
        # name would break the statement.
        results = self.__mysql.select_one(f"select * from code_third_blocks where _id ='{_id}'")
        if results:
            self.__mysql.execute(
                f"update code_third_blocks set _blocks = '{blocks_str}', _update_time = now() where _id ='{_id}'")
        else:
            self.__mysql.execute(
                "insert into code_third_blocks(_id, _code, _source, _blocks,_create_time) values('{}','{}',{},'{}',now())".format(
                    _id, code, source_type, blocks_str))

    def list(self, code=None, source_type=None, max_update_time=None):
        """
        Query raw ``code_third_blocks`` rows with optional filters.

        @param code: only rows for this code (ignored when falsy)
        @param source_type: only rows from this source (ignored when None)
        @param max_update_time: only rows never updated or updated before this time
        @return: rows as returned by ``Mysqldb.select_all``
        """
        sql = f"select * from code_third_blocks where 1=1 "
        if code:
            sql += f" and _code = '{code}'"
        if source_type is not None:
            sql += f" and _source={source_type}"
        if max_update_time:
            sql += f" and (_update_time is null or _update_time<'{max_update_time}')"
        return self.__mysql.select_all(sql)

    def list_all_blocks(self, source_type):
        """Return the ``_blocks`` column of every row from ``source_type``."""
        sql = f"select _blocks from code_third_blocks where _source = '{source_type}'"
        return self.__mysql.select_all(sql)


class BlockMapManager:
    """
    Block-name mapping manager (singleton).

    Loads ``block_map`` (``origin_block`` -> separator-joined ``blocks``) into
    a class-level dict on first instantiation.
    """
    __mysql = Mysqldb()
    __instance = None
    # origin block name -> set of mapped block names
    __block_map = {}

    def __new__(cls, *args, **kwargs):
        if not cls.__instance:
            cls.__instance = super().__new__(cls)
            cls.__load_data()
        return cls.__instance

    @classmethod
    def __load_data(cls):
        """Rebuild the mapping cache from the ``block_map`` table."""
        results = cls.__mysql.select_all("select origin_block,blocks from block_map")
        cls.__block_map.clear()
        for result in results:
            cls.__block_map[result[0]] = _split_blocks(result[1])

    def set_block_map(self, origin_block, blocks):
        """
        Insert or update the mapping of ``origin_block`` and refresh the cache.

        @param origin_block: source block name
        @param blocks: mapped block names; empty means "maps to itself"
        """
        if not blocks:
            blocks = {origin_block}
        blocks_str = _BLOCK_SEPARATOR.join(blocks)
        result = self.__mysql.select_one(f"select * from block_map where origin_block='{origin_block}'")
        if result:
            self.__mysql.execute(
                f"update block_map set blocks='{blocks_str}', update_time=now() where origin_block='{origin_block}'")
        else:
            self.__mysql.execute(
                f"insert into block_map(origin_block, blocks, create_time) values('{origin_block}','{blocks_str}', now())")
        # Keep the in-memory map in sync with what a reload would produce.
        self.__block_map[origin_block] = _split_blocks(blocks_str)

    def get_map_blocks_cache(self, block):
        """
        Return the cached mapped blocks of ``block``.

        @param block: block name
        @return: set of mapped names, or None if unmapped
        """
        return self.__block_map.get(block)

    def filter_blocks(self, blocks):
        """
        Normalize a collection of block names.

        For each name, a trailing "concept" suffix ("概念") is stripped and
        empty names are skipped. Any mapped blocks are then added, and the name
        itself is kept unless it is in the invalid-block set. Mapped blocks are
        not checked against the invalid set.

        @param blocks: iterable of block names (None/empty allowed)
        @return: a new set of block names
        """
        if not blocks:
            return set()
        invalid_blocks = InvalidBlockManager().get_invalid_blocks()
        filtered = set()
        for block in blocks:
            if block.endswith("概念"):
                block = block[:-2]
            if not block:
                continue
            mapped = self.get_map_blocks_cache(block)
            if mapped:
                filtered |= mapped
            if block not in invalid_blocks:
                filtered.add(block)
        return filtered

    def get_all_blocks(self):
        """Return a live keys view of all mapped origin block names."""
        return self.__block_map.keys()


class InvalidBlockManager:
    """
    Invalid-block manager (singleton).

    Loads the non-empty ``_block`` values of ``invalid_block`` into a
    class-level set on first instantiation.
    """
    __mysql = Mysqldb()
    __instance = None
    __block = set()

    def __new__(cls, *args, **kwargs):
        if not cls.__instance:
            cls.__instance = super().__new__(cls)
            cls.__load_data()
        return cls.__instance

    @classmethod
    def __load_data(cls):
        """Rebuild the invalid-block cache from the ``invalid_block`` table."""
        results = cls.__mysql.select_all("select _block from invalid_block")
        cls.__block.clear()
        cls.__block.update(r[0] for r in results if r[0])

    def get_invalid_blocks(self):
        """
        Return the invalid blocks.

        @return: the live cached set (mutations are visible to all holders)
        """
        return self.__block

    def set_incalid_blocks(self, blocks):
        """
        Replace all invalid blocks in the DB and in the cache.

        @param blocks: iterable of block names
        """
        # Materialize once: iterated for the inserts and again for the cache.
        blocks = list(blocks)
        # Delete everything first, then re-insert.
        self.__mysql.execute("delete from invalid_block")
        for b in blocks:
            self.__mysql.execute(f"insert into invalid_block(_block) values('{b}')")
        # Update the shared class-level set in place rather than rebinding an
        # instance attribute, so previously returned references stay current.
        self.__block.clear()
        self.__block.update(blocks)


def load_if_less(codes):
    """
    Fetch and store missing EASTMONEY/TDX/THS blocks for each code.

    A source counts as missing when it is absent from the code's raw cache.
    Best-effort: a failure for one (code, source) pair is logged at DEBUG
    level and skipped.

    @param codes: iterable of stock codes
    """
    manager = CodeThirdBlocksManager()
    wanted_sources = {SOURCE_TYPE_EASTMONEY, SOURCE_TYPE_TDX, SOURCE_TYPE_THS}
    for code in codes:
        source_blocks = manager.get_source_blocks_origin(code) or {}
        for source in wanted_sources - source_blocks.keys():
            try:
                blocks = middle_api_protocol.request(middle_api_protocol.get_third_blocks(code, source))
                if blocks:
                    manager.set_blocks(code, blocks, source)
            except Exception:
                _logger.debug("load third blocks failed: code=%s source=%s", code, source, exc_info=True)


def __add_invlaid_blocks():
    """One-off maintenance helper: replace the invalid-block table with a fixed list."""
    blocks_str = """
    昨日连板
    昨日连板_含一字
    昨日涨停
    昨日涨停_含一字
    """
    blocks = {x.strip() for x in blocks_str.split("\n") if x.strip()}
    print(len(blocks))
    InvalidBlockManager().set_incalid_blocks(blocks)


if __name__ == '__main__':
    pass