"""
|
同花顺行业工具
|
"""
|
|
# 同花顺行业
|
import time
|
|
from code_attribute import global_data_loader
|
from utils import global_util, tool
|
from db import mysql_data
|
|
# 获取行业映射
|
|
|
def get_code_industry_maps():
    """Build code <-> industry lookup tables from the ``ths_industry_codes`` table.

    Returns:
        A ``(code_map, industry_map)`` tuple:
        ``code_map`` maps each ``_id`` to its ``second_industry`` value;
        ``industry_map`` maps each ``second_industry`` value to the set of
        ``_id`` values that carry it.
    """
    rows = mysql_data.Mysqldb().select_all("select _id,second_industry from ths_industry_codes")
    code_to_industry = {}
    industry_to_codes = {}
    for row in rows:
        stock_code, industry_name = row[0], row[1]
        code_to_industry[stock_code] = industry_name
        # Group codes by industry, creating the bucket on first sight.
        industry_to_codes.setdefault(industry_name, set()).add(stock_code)
    return code_to_industry, industry_to_codes
|
|
|
# 设置行业热度
|
def set_industry_hot_num(limit_up_datas):
    """Aggregate limit-up percentages per industry into ``global_util.industry_hot_num``.

    An entry is counted only if its ``"time"`` is ``"00:00:00"`` or not later
    than the current time. For each counted entry whose code has a known
    industry:
    - ``limitUpPercent`` is capped at 21 and rounded to 2 decimals.
    - The result is stored in ``global_util.limit_up_codes_percent[code]``.
    - The result is added (rounded to 2 decimals) to its industry's running total.

    Entries whose code has no known industry are skipped.

    Args:
        limit_up_datas: iterable of dicts with at least the keys ``"time"``,
            ``"code"`` and ``"limitUpPercent"``. ``None`` is a no-op.

    Raises:
        Exception: if the code -> industry map is still ``None`` after
            reloading it via ``global_data_loader.load_industry()``.
    """
    if limit_up_datas is None:
        return

    code_industry_map = global_util.code_industry_map
    if not code_industry_map:
        # Map missing or empty: (re)load it once before giving up.
        global_data_loader.load_industry()
        code_industry_map = global_util.code_industry_map
        if code_industry_map is None:
            raise Exception("获取代码对应的行业出错")

    # "Now" is the same for every entry, so convert it to seconds only once.
    now_seconds = tool.get_time_as_second(tool.get_now_time_str())

    industry_hot_dict = {}
    for data in limit_up_datas:
        # Only limit-ups that already happened count ("00:00:00" always counts).
        if data["time"] != "00:00:00" and now_seconds < tool.get_time_as_second(data["time"]):
            continue

        code = data["code"]
        industry = code_industry_map.get(code)
        if industry is None:
            # No industry known for this code.
            continue

        # Cap the contribution of a single code at 21 before rounding.
        percent = round(min(float(data["limitUpPercent"]), 21), 2)
        # Remember the (capped) percentage per code.
        global_util.limit_up_codes_percent[code] = percent

        industry_hot_dict[industry] = round(industry_hot_dict.get(industry, 0) + percent, 2)

    global_util.industry_hot_num = industry_hot_dict
|
|
|
# 获取相同行业的代码
|
# 返回:行业,同行业代码
|
def get_same_industry_codes(code, codes):
    """Return the industry of ``code`` and the members of ``codes`` that share it.

    If ``code`` is not found in ``global_util.code_industry_map``, the map is
    reloaded once via ``global_data_loader.load_industry()`` and looked up again.

    Args:
        code: the reference stock code.
        codes: iterable of candidate codes.

    Returns:
        ``(industry, same_codes)``, where ``same_codes`` is the set of entries
        of ``codes`` mapped to the same industry. Returns ``(None, None)`` when
        the industry of ``code`` is unknown.
    """
    # The shared map may not be loaded yet (None). Treat that as a lookup miss
    # instead of failing with AttributeError on ``.get``.
    industry_map = global_util.code_industry_map or {}
    industry = industry_map.get(code)
    if industry is None:
        global_data_loader.load_industry()
        industry_map = global_util.code_industry_map or {}
        industry = industry_map.get(code)
        if industry is None:
            return None, None

    same_codes = {code_ for code_ in codes if industry_map.get(code_) == industry}
    return industry, same_codes
|
|
|
# 获取这一批数据的行业
|
def __get_industry(datas):
    """Resolve a single industry name for a batch of codes.

    Queries ``ths_industry`` for rows whose ``first_code`` equals any code in
    ``datas``. Returns the first column of the first returned row (presumably
    the second-level industry name; confirm against the table schema). The
    result is printed before returning.

    Args:
        datas: iterable of dicts, each carrying a ``"code"`` key.

    Returns:
        The first column of the first matching row, or ``None`` when nothing
        matches or ``datas`` contains no codes.
    """
    codes = {data["code"] for data in datas}

    _fname = None
    # An empty batch would produce "... where " (invalid SQL), so skip the query.
    if codes:
        # NOTE(review): codes are interpolated straight into the SQL text;
        # switch to a parameterized query if they can come from untrusted input.
        where_clause = " or ".join("first_code='{}'".format(code) for code in codes)
        mysqldb = mysql_data.Mysqldb()
        results = mysqldb.select_all("select * from ths_industry where {}".format(where_clause))
        for row in results:
            _fname = row[0]
            break

    print("最终的二级行业名称为:", _fname)
    return _fname
|
|
|
# 保存单个代码的行业
|
def __save_code_industry(code, code_name, industry_name, zyltgb, zyltgb_unit):
    """Insert or update the ``ths_industry_codes`` row for one code.

    If no row with ``_id == code`` exists, a new row is inserted with all the
    given values. Otherwise the row's ``second_industry``, ``zyltgb`` and
    ``zyltgb_unit`` are updated. ``_name`` is updated only when ``code_name``
    is truthy, so a missing name does not overwrite a stored one.

    NOTE(review): values are interpolated directly into SQL text, and a ``None``
    ``code_name``/``industry_name`` is written on insert as the string 'None'.
    Confirm both are acceptable for the callers.
    """
    mysqldb = mysql_data.Mysqldb()
    # Quote the code like the INSERT/UPDATE below do. Unquoted, a code such as
    # 000001 would be compared numerically rather than as a string.
    result = mysqldb.select_one("select * from ths_industry_codes where _id='{}'".format(code))
    if result is None:
        mysqldb.execute(
            "insert into ths_industry_codes(_id,_name, second_industry,zyltgb,zyltgb_unit) values('{}','{}','{}','{}',{})".format(
                code, code_name, industry_name, zyltgb, zyltgb_unit))
    elif code_name:
        mysqldb.execute(
            "update ths_industry_codes set _name='{}', second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                code_name, industry_name, zyltgb, zyltgb_unit, code))
    else:
        mysqldb.execute(
            "update ths_industry_codes set second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                industry_name, zyltgb, zyltgb_unit, code))
|
|
|
# 保存行业代码
|
def save_industry_code(datasList, code_names):
    """Persist industry assignments for several batches of codes.

    Each batch in ``datasList`` is resolved to one industry via
    ``__get_industry``. Every code in that batch is then saved with:
    - that industry;
    - its name from ``code_names`` (``None`` if absent);
    - its ``zyltgb`` and ``zyltgb_unit`` values.

    Args:
        datasList: iterable of batches; each batch is an iterable of dicts with
            ``"code"``, ``"zyltgb"`` and ``"zyltgb_unit"`` keys.
        code_names: mapping from code to display name.

    NOTE(review): if ``__get_industry`` finds no match it returns ``None``, and
    that value is still passed on to be saved. Confirm this is intended.
    """
    for datas in datasList:
        # All codes of a batch share the industry resolved for the batch.
        industry_name = __get_industry(datas)
        for data in datas:
            code = data["code"]
            __save_code_industry(code, code_names.get(code), industry_name, data["zyltgb"], data["zyltgb_unit"])
|
|
|
# 根据名称获取代码
|
def get_code_by_name(name):
    """Look up a code by its display name in ``ths_industry_codes``.

    Returns the first column (``_id``) of the row whose ``_name`` equals
    ``name``, or ``None`` when no such row exists.
    """
    row = mysql_data.Mysqldb().select_one("select * from ths_industry_codes where _name='{}'".format(name))
    return None if row is None else row[0]
|
|
|
def get_name_by_code(code):
    """Look up the display name of ``code`` in ``ths_industry_codes``.

    Returns the second column of the row whose ``_id`` equals ``code``, or
    ``None`` when no such row exists.
    """
    mysqldb = mysql_data.Mysqldb()
    # Quote the code so it is matched as a string, consistent with the other
    # queries on this table. Unquoted, 000001 would be compared as the number 1.
    result = mysqldb.select_one("select * from ths_industry_codes where _id='{}'".format(code))
    return result[1] if result is not None else None
|
|
if __name__ == "__main__":
    # Manual smoke check: print both lookup tables built from the database.
    _maps = get_code_industry_maps()
    print(*_maps)
|