"""
Tonghuashun (THS) industry utilities.

Helpers for reading and writing the ``ths_industry`` / ``ths_industry_codes``
tables and for maintaining the industry-related caches held in
``utils.global_util``.
"""

# THS industry helpers
import time

from code_attribute import global_data_loader
from db.mysql_data_delegate import Mysqldb
from utils import global_util, tool
from db import mysql_data_delegate as mysql_data


# Code -> industry lookup
class ThsCodeIndustryManager:
    """Process-wide singleton caching the code -> second_industry mapping.

    The mapping is read from ``ths_industry_codes`` the first time the class
    is instantiated; every later instantiation returns the same object and
    does not query the database again.
    """

    # The single shared instance (None until the first successful construction).
    __instance = None
    # Database delegate shared by the class; created when the class is defined.
    __mysql = Mysqldb()
    # Cache: code (_id column) -> second_industry column.
    __code_industry = {}

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, loading the cache on first use.

        Positional/keyword arguments are accepted for backward compatibility
        but ignored.
        """
        if cls.__instance is None:
            # object.__new__ rejects extra arguments, so they are deliberately
            # not forwarded.
            instance = super(ThsCodeIndustryManager, cls).__new__(cls)
            cls.__load_data()
            # Publish the instance only after the load succeeded, so a failed
            # load is retried on the next instantiation instead of leaving a
            # permanently empty cache behind.
            cls.__instance = instance
        return cls.__instance

    @classmethod
    def __load_data(cls):
        """Fill the class-level cache from ``ths_industry_codes``."""
        results = cls.__mysql.select_all("select _id,second_industry from ths_industry_codes")
        if results:
            for row in results:
                cls.__code_industry[row[0]] = row[1]

    def get_industry(self, code):
        """Return the cached industry for ``code``, or None if it is unknown."""
        return self.__code_industry.get(code)


def get_code_industry_maps():
    """Build both directions of the code/industry mapping from the database.

    Returns:
        A tuple ``(code_map, industry_map)`` where ``code_map`` maps each code
        to its industry and ``industry_map`` maps each industry to the set of
        codes that belong to it. Both are empty when the query yields nothing.
    """
    code_map = {}
    industry_map = {}
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select _id,second_industry from ths_industry_codes")
    if results:
        for row in results:
            code, industry = row[0], row[1]
            code_map[code] = industry
            industry_map.setdefault(industry, set()).add(code)
    return code_map, industry_map


# Set the industry hotness
def set_industry_hot_num(limit_up_datas):
    """Recompute per-industry hotness from a batch of limit-up records.

    For every accepted record the percentage (capped at 21, rounded to two
    decimals) is stored in ``global_util.limit_up_codes_percent`` and added to
    the running total of the record's industry. The totals replace
    ``global_util.industry_hot_num``.

    Args:
        limit_up_datas: iterable of dicts providing the keys ``"time"``,
            ``"code"`` and ``"limitUpPercent"``. ``None`` is a no-op.

    Raises:
        Exception: if the code -> industry map is still None after attempting
            to load it.
    """
    if limit_up_datas is None:
        return
    code_industry_map = global_util.code_industry_map
    if not code_industry_map:
        # Map missing or empty: load it once, then re-read the global.
        global_data_loader.load_industry()
        code_industry_map = global_util.code_industry_map
    if code_industry_map is None:
        raise Exception("获取代码对应的行业出错")

    # The current time does not change per record, so convert it only once.
    now_seconds = tool.get_time_as_second(tool.get_now_time_str())
    industry_hot_dict = {}
    for data in limit_up_datas:
        # Only records whose time is not later than now count; the placeholder
        # time "00:00:00" is exempt from this check.
        if data["time"] != "00:00:00" and now_seconds < tool.get_time_as_second(data["time"]):
            continue
        code = data["code"]
        industry = code_industry_map.get(code)
        if industry is None:
            # No industry known for this code: skip the record.
            continue
        # Cap the percentage at 21 before rounding.
        percent = round(min(float(data["limitUpPercent"]), 21), 2)
        # Remember the (capped) percentage per code.
        global_util.limit_up_codes_percent[code] = percent
        industry_hot_dict[industry] = round(industry_hot_dict.get(industry, 0) + percent, 2)
    global_util.industry_hot_num = industry_hot_dict


# Get the codes that share an industry
# Returns: industry, codes in the same industry
def get_same_industry_codes(code, codes):
    """Return the industry of ``code`` and the members of ``codes`` sharing it.

    The code -> industry map is (re)loaded once if ``code`` cannot be resolved
    from the currently cached map.

    Args:
        code: the reference code.
        codes: iterable of candidate codes.

    Returns:
        ``(industry, same_industry_codes)`` where the second element is a set,
        or ``(None, None)`` when the industry of ``code`` cannot be determined.
    """
    code_industry_map = global_util.code_industry_map
    # Tolerate a map that has not been loaded yet (None) instead of crashing.
    industry = code_industry_map.get(code) if code_industry_map else None
    if industry is None:
        global_data_loader.load_industry()
        code_industry_map = global_util.code_industry_map
        industry = code_industry_map.get(code) if code_industry_map else None
    if industry is None:
        return None, None
    same_industry = {c for c in codes if code_industry_map.get(c) == industry}
    return industry, same_industry


# Get the industry of one batch of data
def __get_industry(datas):
    """Look up the industry that a batch of records belongs to.

    Queries ``ths_industry`` for rows whose ``first_code`` equals any code in
    the batch and returns column 0 of the first row found (presumably the
    second-level industry name, going by the log message).

    Args:
        datas: iterable of dicts providing the key ``"code"``.

    Returns:
        Column 0 of the first matching row, or None when the batch is empty
        or nothing matches.
    """
    codes = {data["code"] for data in datas}
    _fname = None
    if codes:
        # NOTE(review): the SQL is assembled by string formatting; codes must
        # come from a trusted source.
        conditions = " or ".join("first_code='{}'".format(code) for code in codes)
        mysqldb = mysql_data.Mysqldb()
        results = mysqldb.select_all("select * from ths_industry where {}".format(conditions))
        # Other callers in this module treat a falsy result as "no rows".
        for row in results or ():
            _fname = row[0]
            break
    print("最终的二级行业名称为:", _fname)
    return _fname


# Save the industry of a single code
def __save_code_industry(code, code_name, industry_name, zyltgb, zyltgb_unit):
    """Insert or update the ``ths_industry_codes`` row for ``code``.

    Args:
        code: value of the ``_id`` column.
        code_name: value for ``_name``; when falsy, an existing row keeps its
            current name.
        industry_name: value for ``second_industry``.
        zyltgb: value for the ``zyltgb`` column (written as a quoted literal).
        zyltgb_unit: value for the ``zyltgb_unit`` column (written unquoted,
            so presumably numeric).
    """
    # NOTE(review): all statements are assembled by string formatting; the
    # arguments must come from a trusted source.
    mysqldb = mysql_data.Mysqldb()
    # _id is quoted here exactly as in the UPDATE statements below, so the
    # existence check and the update always agree on which row they address.
    result = mysqldb.select_one("select * from ths_industry_codes where _id='{}'".format(code))
    if result is None:
        mysqldb.execute(
            "insert into ths_industry_codes(_id,_name, second_industry,zyltgb,zyltgb_unit) values('{}','{}','{}','{}',{})".format(
                code, code_name, industry_name, zyltgb, zyltgb_unit))
    elif code_name:
        mysqldb.execute(
            "update ths_industry_codes set _name='{}', second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                code_name, industry_name, zyltgb, zyltgb_unit, code))
    else:
        # No name supplied: leave the stored name untouched.
        mysqldb.execute(
            "update ths_industry_codes set second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                industry_name, zyltgb, zyltgb_unit, code))


# Save industry codes
def save_industry_code(datasList, code_names):
    """Persist the industry for every record of every batch.

    Args:
        datasList: iterable of batches; each batch is an iterable of dicts
            providing the keys ``"code"``, ``"zyltgb"`` and ``"zyltgb_unit"``.
        code_names: mapping from code to name (looked up with ``.get``).
    """
    for datas in datasList:
        # Determine which industry this batch belongs to.
        industry_name = __get_industry(datas)
        for data in datas:
            # Save
            code = data["code"]
            __save_code_industry(code, code_names.get(code), industry_name, data["zyltgb"],
                                 data["zyltgb_unit"])


def save_code_industry(code, code_name, industry):
    """Persist the industry of a single code, writing 0 for zyltgb and its unit."""
    __save_code_industry(code, code_name, industry, 0, 0)


# Get the code by name
def get_code_by_name(name):
    """Return column 0 of the ``ths_industry_codes`` row named ``name``.

    Column 0 is presumably ``_id`` (the query is ``select *``, so this relies
    on the table's column order). Returns None when no row matches.
    """
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _name='{}'".format(name))
    if result is None:
        return None
    return result[0]


def get_name_by_code(code):
    """Return column 1 of the ``ths_industry_codes`` row whose ``_id`` is ``code``.

    Column 1 is presumably ``_name`` (the query is ``select *``, so this relies
    on the table's column order). Returns None when no row matches.
    """
    mysqldb = mysql_data.Mysqldb()
    # _id is quoted, consistent with every other statement that filters on it.
    result = mysqldb.select_one("select * from ths_industry_codes where _id='{}'".format(code))
    if result is None:
        return None
    return result[1]


if __name__ == "__main__":
    _code_map, _industry_map = get_code_industry_maps()
    print(_code_map, _industry_map)