Administrator
2023-01-16 6f324f1471a5e28188e9f4206b46cbafdf09d04c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
同花顺行业工具
"""
 
# 同花顺行业
import datetime
import time
 
import global_data_loader
import global_util
import mysql_data
 
# 获取行业映射
import tool
 
 
def get_code_industry_maps():
    __code_map = {}
    __industry_map = {}
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select * from ths_industry_codes")
    for r in results:
        code = r[0]
        industry = r[1]
        __code_map[code] = industry
        if __industry_map.get(industry) is None:
            __industry_map[industry] = set()
        __industry_map[industry].add(code)
    return __code_map, __industry_map
 
 
# 设置行业热度
def set_industry_hot_num(limit_up_datas):
    if limit_up_datas is None:
        return
    industry_hot_dict = {}
    code_industry_map = global_util.code_industry_map
    if code_industry_map is None or len(code_industry_map) == 0:
        global_data_loader.load_industry()
        code_industry_map = global_util.code_industry_map
    if code_industry_map is None:
        raise Exception("获取代码对应的行业出错")
 
    now_str = tool.get_now_time_str()
    for data in limit_up_datas:
        # 时间比现在早的时间才算数
        if data["time"] != "00:00:00" and tool.get_time_as_second(now_str) < tool.get_time_as_second(
                data["time"]):
            continue
 
        code = data["code"]
        industry = code_industry_map.get(code)
        if industry is None:
            # 获取代码对应的行业出错
            continue
        if industry_hot_dict.get(industry) is None:
            industry_hot_dict.setdefault(industry, 0)
 
        percent = float(data["limitUpPercent"])
        if percent > 21:
            percent = 21
        percent = round(percent, 2)
        # 保存涨幅
        global_util.limit_up_codes_percent[code] = percent
 
        industry_hot_dict[industry] = round(industry_hot_dict[industry] + percent, 2)
 
    global_util.industry_hot_num = industry_hot_dict
 
 
# 获取相同行业的代码
# 返回:行业,同行业代码
def get_same_industry_codes(code, codes):
    industry = global_util.code_industry_map.get(code)
    if industry is None:
        global_data_loader.load_industry()
        industry = global_util.code_industry_map.get(code)
    if industry is None:
        return None, None
    codes_ = set()
    for code_ in codes:
        if global_util.code_industry_map.get(code_) == industry:
            # 同一行业
            codes_.add(code_)
    return industry, codes_
 
 
# 获取这一批数据的行业
def __get_industry(datas):
    ors = []
    codes = set()
    for data in datas:
        codes.add(data["code"])
 
    " or ".join(codes)
    for code in codes:
        ors.append("first_code='{}'".format(code))
 
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select * from ths_industry where {}".format(" or ".join(ors)))
 
    _fname = None
    for a in results:
        _fname = a[0]
        break
    print("最终的二级行业名称为:", _fname)
    return _fname
 
 
# 保存单个代码的行业
def __save_code_industry(code, industry_name, zyltgb, zyltgb_unit):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _id={}".format(code))
    if result is None:
        mysqldb.execute(
            "insert into ths_industry_codes(_id,second_industry,zyltgb,zyltgb_unit) values('{}','{}','{}',{})".format(
                code, industry_name, zyltgb, zyltgb_unit, round(time.time() * 1000)))
    else:
        mysqldb.execute(
            "update ths_industry_codes set second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                industry_name, zyltgb, zyltgb_unit, code))
 
 
# 保存行业代码
def save_industry_code(datasList):
    for datas in datasList:
        # 查询这批数据所属行业
        industry_name = __get_industry(datas)
        _list = []
        for data in datas:
            # 保存
            __save_code_industry(data["code"], industry_name, data["zyltgb"], data["zyltgb_unit"])
 
 
if __name__ == "__main__":
    _code_map, _industry_map = get_code_industry_maps()
    print(_code_map, _industry_map)