Administrator
2022-10-21 892b50e242e3c59a738b92dfdfee1bf1ff8932f2
ths_industry_util.py
@@ -3,18 +3,21 @@
"""
# 同花顺行业
import time
import global_util
import mongo_data
import mysql_data
# 获取行业映射
def get_code_industry_maps():
    __code_map = {}
    __industry_map = {}
    results = mongo_data.find("ths-industry-codes", {})
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select * from ths_industry_codes")
    for r in results:
        code = r["_id"]
        industry = r["second_industry"]
        code = r[0]
        industry = r[1]
        __code_map[code] = industry
        if __industry_map.get(industry) is None:
            __industry_map[industry] = set()
@@ -24,6 +27,8 @@
# 设置行业热度
def set_industry_hot_num(limit_up_datas):
    if limit_up_datas is None:
        return
    industry_hot_dict = {}
    code_industry_map = global_util.code_industry_map
    if code_industry_map is None or len(code_industry_map) == 0:
@@ -44,6 +49,10 @@
        percent = float(data["limitUpPercent"])
        if percent > 21:
            percent = 21
        percent = round(percent, 2)
        # 保存涨幅
        global_util.limit_up_codes_percent[code] = percent
        industry_hot_dict[industry] = round(industry_hot_dict[industry] + percent, 2)
    global_util.industry_hot_num = industry_hot_dict
@@ -66,6 +75,53 @@
    return industry, codes_
# 获取这一批数据的行业
def __get_industry(datas):
    ors = []
    codes = set()
    for data in datas:
        codes.add(data["code"])
    " or ".join(codes)
    for code in codes:
        ors.append("first_code='{}'".format(code))
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select * from ths_industry where {}".format(" or ".join(ors)))
    _fname = None
    for a in results:
        _fname = a[0]
        break
    print("最终的二级行业名称为:", _fname)
    return _fname
# 保存单个代码的行业
def __save_code_industry(code, industry_name, zyltgb, zyltgb_unit):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _id={}".format(code))
    if result is None:
        mysqldb.insert(
            "insert into ths_industry_codes(code,industry_name,zyltgb,zyltgb_unit) values('{}','{}','{}',{})".format(
                code, industry_name, zyltgb, zyltgb_unit, round(time.time() * 1000)))
    else:
        mysqldb.update(
            "update ths_industry_codes set industry_name='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                industry_name, zyltgb, zyltgb_unit, code))
# 保存行业代码
def save_industry_code(datasList):
    for datas in datasList:
        # 查询这批数据所属行业
        industry_name = __get_industry(datas)
        _list = []
        for data in datas:
            # 保存
            __save_code_industry(data["code"],industry_name,data["zyltgb"],data["zyltgb_unit"])
if __name__ == "__main__":
    _code_map, _industry_map = get_code_industry_maps()
    print(_code_map, _industry_map)