Administrator
2023-03-27 8535f56dbf6e410b4a09f02f95d4d49bcc8753f2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
同花顺行业工具
"""
 
# 同花顺行业
import time
 
import global_data_loader
import global_util
from db import mysql_data
 
# 获取行业映射
import tool
 
 
def get_code_industry_maps():
    __code_map = {}
    __industry_map = {}
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select * from ths_industry_codes")
    for r in results:
        code = r[0]
        industry = r[1]
        __code_map[code] = industry
        if __industry_map.get(industry) is None:
            __industry_map[industry] = set()
        __industry_map[industry].add(code)
    return __code_map, __industry_map
 
 
# 设置行业热度
def set_industry_hot_num(limit_up_datas):
    if limit_up_datas is None:
        return
    industry_hot_dict = {}
    code_industry_map = global_util.code_industry_map
    if code_industry_map is None or len(code_industry_map) == 0:
        global_data_loader.load_industry()
        code_industry_map = global_util.code_industry_map
    if code_industry_map is None:
        raise Exception("获取代码对应的行业出错")
 
    now_str = tool.get_now_time_str()
    for data in limit_up_datas:
        # 时间比现在早的时间才算数
        if data["time"] != "00:00:00" and tool.get_time_as_second(now_str) < tool.get_time_as_second(
                data["time"]):
            continue
 
        code = data["code"]
        industry = code_industry_map.get(code)
        if industry is None:
            # 获取代码对应的行业出错
            continue
        if industry_hot_dict.get(industry) is None:
            industry_hot_dict.setdefault(industry, 0)
 
        percent = float(data["limitUpPercent"])
        if percent > 21:
            percent = 21
        percent = round(percent, 2)
        # 保存涨幅
        global_util.limit_up_codes_percent[code] = percent
 
        industry_hot_dict[industry] = round(industry_hot_dict[industry] + percent, 2)
 
    global_util.industry_hot_num = industry_hot_dict
 
 
# 获取相同行业的代码
# 返回:行业,同行业代码
def get_same_industry_codes(code, codes):
    industry = global_util.code_industry_map.get(code)
    if industry is None:
        global_data_loader.load_industry()
        industry = global_util.code_industry_map.get(code)
    if industry is None:
        return None, None
    codes_ = set()
    for code_ in codes:
        if global_util.code_industry_map.get(code_) == industry:
            # 同一行业
            codes_.add(code_)
    return industry, codes_
 
 
# 获取这一批数据的行业
def __get_industry(datas):
    ors = []
    codes = set()
    for data in datas:
        codes.add(data["code"])
 
    " or ".join(codes)
    for code in codes:
        ors.append("first_code='{}'".format(code))
 
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select * from ths_industry where {}".format(" or ".join(ors)))
 
    _fname = None
    for a in results:
        _fname = a[0]
        break
    print("最终的二级行业名称为:", _fname)
    return _fname
 
 
# 保存单个代码的行业
def __save_code_industry(code, code_name, industry_name, zyltgb, zyltgb_unit):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _id={}".format(code))
    if result is None:
        mysqldb.execute(
            "insert into ths_industry_codes(_id,_name, second_industry,zyltgb,zyltgb_unit) values('{}','{}','{}','{}',{})".format(
                code, code_name, industry_name, zyltgb, zyltgb_unit, round(time.time() * 1000)))
    else:
        if code_name:
            mysqldb.execute(
                "update ths_industry_codes set _name='{}', second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                    code_name, industry_name, zyltgb, zyltgb_unit, code))
        else:
            mysqldb.execute(
                "update ths_industry_codes set second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                     industry_name, zyltgb, zyltgb_unit, code))
 
 
# 保存行业代码
def save_industry_code(datasList, code_names):
    for datas in datasList:
        # 查询这批数据所属行业
        industry_name = __get_industry(datas)
        _list = []
        for data in datas:
            # 保存
            code = data["code"]
            __save_code_industry(code, code_names.get(code), industry_name, data["zyltgb"], data["zyltgb_unit"])
 
 
# 根据名称获取代码
def get_code_by_name(name):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _name='{}'".format(name))
    if result is not None:
        return result[0]
    else:
        return None
 
 
def get_name_by_code(code):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _id={}".format(code))
    if result is not None:
        return result[1]
    else:
        return None
 
if __name__ == "__main__":
    _code_map, _industry_map = get_code_industry_maps()
    print(_code_map, _industry_map)