]>
Commit | Line | Data |
---|---|---|
1 | from enum import IntEnum, unique | |
2 | import inspect | |
3 | from abc import ABCMeta | |
4 | from baseorm import Observer | |
5 | import logging | |
# Module-level logger used by DB for connection and query tracing.
logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time affects every
# program that imports this module; consider moving this under __main__.
logging.basicConfig(level=logging.INFO)
8 | ||
@unique
class DBDriver(IntEnum):
    """Database back ends that ``DB`` knows how to connect to."""

    PG = 1  # PostgreSQL (``postgres.Postgres``)
    MySQL = 2  # MySQL (``mysqld.Mysql``)
13 | ||
class Singleton(ABCMeta):
    """Metaclass that lets each class using it construct at most one instance.

    The first call ``Cls(...)`` builds the instance with exactly the arguments
    supplied by the caller.  Every later call returns that same instance and
    ignores its arguments.  It derives from ``ABCMeta``, presumably so it can
    be combined with base classes (such as ``Observer``) that use ``ABCMeta``.
    """

    # Maps each class using this metaclass to its single instance.  The strong
    # reference here keeps the instance alive for the life of the process.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the class's sole instance, creating it from the given arguments."""
        if cls not in cls._instances:
            # Pass the caller's arguments through unchanged.
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
20 | ||
class DB(Observer, metaclass=Singleton):
    """Database facade that persists model objects via a Postgres or MySQL driver.

    Construction goes through the ``Singleton`` metaclass, so the first
    ``DB(...)`` call creates the connection and later calls return the same
    object.  As an observer, ``update`` stores whichever subject notified it.
    """

    # Schema script run by initDb(); chosen per driver in __init__.
    __filename__ = None

    def __init__(self, password, driver=DBDriver.PG, user='pwp', database='pwp', host='localhost', port=None):
        """Record connection settings, connect, and select the schema file.

        :param password: database password.
        :param driver: a ``DBDriver`` member or its integer value.
        :param user: database user name.
        :param database: database name.
        :param host: database host.
        :param port: database port, or ``None`` to let the driver choose.
        :raises Exception: if ``driver`` is unknown or the connection fails.
        """
        try:
            # Normalize plain ints (e.g. 1) to the enum so the identity checks
            # below and ``self.driver.name`` in __repr__ behave consistently.
            driver = DBDriver(driver)
        except ValueError:
            raise Exception("%s: Unknown DB driver" % (driver,)) from None
        self.password = password
        self.user = user
        self.database = database
        self.driver = driver
        self.host = host
        self.port = port
        self.__connect()
        if driver is DBDriver.PG:
            self.__filename__ = 'db_postgres.sql'
        elif driver is DBDriver.MySQL:
            self.__filename__ = 'db_mysql.sql'

    def __connect(self):
        """Open ``self.db`` with the driver selected by ``self.driver``.

        Driver modules are imported lazily, so only the one in use needs to be
        installed.  Import errors propagate unchanged.  Connection errors are
        re-raised as ``Exception`` with the original error chained as the cause.
        """
        if self.driver is DBDriver.PG:
            from postgres import Postgres as connector
        elif self.driver is DBDriver.MySQL:
            from mysqld import Mysql as connector
        else:
            raise Exception("%s: Unknown DB driver" % (self.driver,))
        logger.info("Connect: database: %s user: %s host: %s port: %s",
                    self.database, self.user, self.host, self.port)
        try:
            self.db = connector(self.database, self.user, self.password, self.host, self.port)
        except Exception as e:
            raise Exception(e) from e

    def _quote(self, name):
        """Return ``name`` as an SQL identifier for the active driver.

        ``user`` is a reserved word in PostgreSQL, so a table or column with
        that name is double-quoted on Postgres.  Every other identifier is
        returned verbatim.
        """
        if name.lower() == 'user' and self.driver is DBDriver.PG:
            return '"{0}"'.format(name)
        return name

    def update(self, subject):
        """Observer hook: persist ``subject`` whenever it notifies this DB."""
        self.store(subject)

    def initDb(self):
        """Read the driver's schema script and execute it through ``self.db.ddl``.

        The path in ``self.__filename__`` is relative, so it is resolved
        against the current working directory.

        :raises OSError: if the schema file cannot be opened or read.
        """
        with open(self.__filename__, "r", encoding="utf-8") as file:
            sql = file.read()
        self.db.ddl(sql)

    def query(self, object, **kwargs):
        """Select all rows of table ``object`` matching every ``column=value`` given.

        Values are sent as driver placeholders.  The table and column names are
        interpolated into the SQL text, so they must come from trusted code.

        :returns: whatever ``self.db.query`` returns for the statement.
        :raises Exception: if no filter keyword is supplied.
        """
        if not kwargs:
            raise Exception("{0}: Missing at least one query parameter".format(object))
        clause = ' and '.join('{0} = %s'.format(self._quote(name)) for name in kwargs)
        placeholder = list(kwargs.values())
        sql = 'select * from {0} where {1}'.format(self._quote(object), clause)
        logger.info("%s -> %s", sql, placeholder)
        return self.db.query(sql, placeholder)

    def store(self, object):
        """Insert or update ``object`` in its table.

        The table is ``object.__tablename__`` if present, otherwise the class
        name.  Every ``property`` on the object's class becomes a column.  A
        property value that has an ``id`` attribute (a related object) is
        stored as that id.

        The ``id`` property selects the action.  ``None`` or a negative value
        (presumably "not yet persisted") inserts a row and writes the id
        returned by ``self.db.insert`` back to ``object.id``.  Any other value
        updates the row with that id.

        :raises Exception: if ``object`` is a class, or has no ``id`` property.
        """
        if inspect.isclass(object):
            raise Exception("{0}: Class not instance".format(object))

        table = self._quote(getattr(object, '__tablename__', object.__class__.__name__))

        owner = type(object)
        row_id = None
        has_id = False
        columns = []
        values = []
        # dir() returns names sorted alphabetically, so column order is stable.
        for name in dir(object):
            if not isinstance(getattr(owner, name, None), property):
                continue
            value = getattr(object, name)
            if name == 'id':
                has_id = True
                row_id = value
            else:
                columns.append(self._quote(name))
                values.append(value.id if hasattr(value, 'id') else value)

        if not has_id:
            # Without an id there is neither an insert marker nor a row to update.
            raise Exception("{0}: Missing id property".format(object))

        if row_id is None or row_id < 0:
            sql = 'insert into {0} ({1}) values ({2}) '.format(
                table, ', '.join(columns), ', '.join(['%s'] * len(values)))
            object.id = self.db.insert(sql, values)
        else:
            assignments = ', '.join('{0} = %s'.format(column) for column in columns)
            # The id is bound as a parameter rather than formatted into the SQL.
            sql = 'update {0} set {1} where id = %s'.format(table, assignments)
            self.db.update(sql, values + [row_id])

    def __repr__(self):
        """Default object repr followed by the driver, user and database."""
        default = '<%s.%s object at %s>' % (self.__class__.__module__, self.__class__.__name__, hex(id(self)))
        return "%s <Driver: %s, User: %s, database: %s>" % (default, self.driver.name, self.user, self.database)

    def __del__(self):
        """Drop the driver handle, tolerating a partially initialised instance.

        Singleton._instances keeps a strong reference to the instance, so a
        plain ``del db`` in caller code does not by itself trigger this method.
        """
        try:
            del self.db
        except AttributeError:
            # __connect failed or never ran, so there is no handle to drop.
            pass
167 | ||
class Test:
    """Scratch class holding a single property ``x`` clamped to the range 0..1000."""

    def __init__(self, x):
        # Assign through the property so the clamping setter applies.
        self.x = x

    @property
    def x(self):
        """The stored value, already clamped by the setter."""
        return self.__x

    @x.setter
    def x(self, x):
        # Values below 0 become 0 and values above 1000 become 1000;
        # anything in between is stored unchanged.
        self.__x = 0 if x < 0 else (1000 if x > 1000 else x)
184 | ||
if __name__ == "__main__":
    # Manual smoke test.  It needs the ``app.models`` package two directories
    # above the current working directory and a reachable database.
    import sys
    sys.path.append('../..')
    from app.models import User, Portfolio, Album, Photo, AccessRight

    def _show(obj):
        """Print ``obj`` and return it, so creation and display read as one step."""
        print(obj)
        return obj

    # Use DB('test', DBDriver.MySQL) instead to exercise the MySQL driver.
    db = DB('test')
    db.initDb()

    user = User('test', 'test@test.dk', 'test', 'sha256$1HX2n73E$ac27f843b4342df7b6c12e5ac340e063ea958d52ce62c3883c124385c96b263a')
    user.addObserver(db)
    _show(user)
    db.store(user)

    portfolio = _show(Portfolio('test', user))
    album = _show(Album('test', portfolio))
    accessright = _show(AccessRight(user))
    photo = _show(Photo('test', album))

    db.store(portfolio)
    user.name = 'MIR'
    # Singleton._instances still references the DB, so this only unbinds the name.
    del db