Administrator
5 天以前 66af232f6d8fb9adee08967ff932f81c37da8d43
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
同花顺行业工具
"""
 
# 同花顺行业
import time
 
from code_attribute import global_data_loader
from db.mysql_data_delegate import Mysqldb
from utils import global_util, tool
from db import mysql_data_delegate as mysql_data
 
 
# 获取行业映射
 
class ThsCodeIndustryManager:
    __instance = None
    __mysql = Mysqldb()
    __code_industry = {}
 
    def __new__(cls, *args, **kwargs):
        if not cls.__instance:
            cls.__instance = super(ThsCodeIndustryManager, cls).__new__(cls, *args, **kwargs)
            cls.__load_data()
        return cls.__instance
 
    @classmethod
    def __load_data(cls):
        results = cls.__mysql.select_all("select _id,second_industry from ths_industry_codes")
        if results:
            for r in results:
                code = r[0]
                industry = r[1]
                cls.__code_industry[code] = industry
 
    def get_industry(self, code):
        return self.__code_industry.get(code)
 
 
def get_code_industry_maps():
    __code_map = {}
    __industry_map = {}
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select _id,second_industry from ths_industry_codes")
    if results:
        for r in results:
            code = r[0]
            industry = r[1]
            __code_map[code] = industry
            if __industry_map.get(industry) is None:
                __industry_map[industry] = set()
            __industry_map[industry].add(code)
    return __code_map, __industry_map
 
 
# 设置行业热度
def set_industry_hot_num(limit_up_datas):
    if limit_up_datas is None:
        return
    industry_hot_dict = {}
    code_industry_map = global_util.code_industry_map
    if code_industry_map is None or len(code_industry_map) == 0:
        global_data_loader.load_industry()
        code_industry_map = global_util.code_industry_map
    if code_industry_map is None:
        raise Exception("获取代码对应的行业出错")
 
    now_str = tool.get_now_time_str()
    for data in limit_up_datas:
        # 时间比现在早的时间才算数
        if data["time"] != "00:00:00" and tool.get_time_as_second(now_str) < tool.get_time_as_second(
                data["time"]):
            continue
 
        code = data["code"]
        industry = code_industry_map.get(code)
        if industry is None:
            # 获取代码对应的行业出错
            continue
        if industry_hot_dict.get(industry) is None:
            industry_hot_dict.setdefault(industry, 0)
 
        percent = float(data["limitUpPercent"])
        if percent > 21:
            percent = 21
        percent = round(percent, 2)
        # 保存涨幅
        global_util.limit_up_codes_percent[code] = percent
 
        industry_hot_dict[industry] = round(industry_hot_dict[industry] + percent, 2)
 
    global_util.industry_hot_num = industry_hot_dict
 
 
# 获取相同行业的代码
# 返回:行业,同行业代码
def get_same_industry_codes(code, codes):
    industry = global_util.code_industry_map.get(code)
    if industry is None:
        global_data_loader.load_industry()
        industry = global_util.code_industry_map.get(code)
    if industry is None:
        return None, None
    codes_ = set()
    for code_ in codes:
        if global_util.code_industry_map.get(code_) == industry:
            # 同一行业
            codes_.add(code_)
    return industry, codes_
 
 
# 获取这一批数据的行业
def __get_industry(datas):
    ors = []
    codes = set()
    for data in datas:
        codes.add(data["code"])
 
    " or ".join(codes)
    for code in codes:
        ors.append("first_code='{}'".format(code))
 
    mysqldb = mysql_data.Mysqldb()
    results = mysqldb.select_all("select * from ths_industry where {}".format(" or ".join(ors)))
 
    _fname = None
    for a in results:
        _fname = a[0]
        break
    print("最终的二级行业名称为:", _fname)
    return _fname
 
 
# 保存单个代码的行业
def __save_code_industry(code, code_name, industry_name, zyltgb, zyltgb_unit):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _id={}".format(code))
    if result is None:
        mysqldb.execute(
            "insert into ths_industry_codes(_id,_name, second_industry,zyltgb,zyltgb_unit) values('{}','{}','{}','{}',{})".format(
                code, code_name, industry_name, zyltgb, zyltgb_unit, round(time.time() * 1000)))
    else:
        if code_name:
            mysqldb.execute(
                "update ths_industry_codes set _name='{}', second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                    code_name, industry_name, zyltgb, zyltgb_unit, code))
        else:
            mysqldb.execute(
                "update ths_industry_codes set second_industry='{}',zyltgb='{}',zyltgb_unit={} where _id='{}'".format(
                    industry_name, zyltgb, zyltgb_unit, code))
 
 
# 保存行业代码
def save_industry_code(datasList, code_names):
    for datas in datasList:
        # 查询这批数据所属行业
        industry_name = __get_industry(datas)
        _list = []
        for data in datas:
            # 保存
            code = data["code"]
            __save_code_industry(code, code_names.get(code), industry_name, data["zyltgb"], data["zyltgb_unit"])
 
 
def save_code_industry(code, code_name, industry):
    __save_code_industry(code, code_name, industry, 0, 0)
 
 
# 根据名称获取代码
def get_code_by_name(name):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _name='{}'".format(name))
    if result is not None:
        return result[0]
    else:
        return None
 
 
def get_name_by_code(code):
    mysqldb = mysql_data.Mysqldb()
    result = mysqldb.select_one("select * from ths_industry_codes where _id={}".format(code))
    if result is not None:
        return result[1]
    else:
        return None
 
 
if __name__ == "__main__":
    _code_map, _industry_map = get_code_industry_maps()
    print(_code_map, _industry_map)