# twitter_db.py — MySQL storage for scraped tweets (with media, hashtags, URLs,
# replies and related tweets) and Twitter users.
import os
from contextlib import closing
from datetime import datetime
from typing import Optional

import mysql.connector
from mysql.connector import pooling

from server.module.TweetModel import TweetModel
from server.module.UserModel import User
  7. # 数据库连接配置
  8. # MySQL 连接配置
  9. DB_CONFIG = {
  10. 'host': '117.78.31.244',
  11. 'user': 'root',
  12. 'password': 'zh123456',
  13. 'database': 'twitter_spider',
  14. 'charset':'utf8mb4'
  15. }
  16. # 假设 tweet.created_at 的格式是 'Thu Feb 20 00:38:20 +0000 2025'
  17. # 你可以使用 datetime 模块来解析并转换时间格式
  18. def convert_to_mysql_datetime(date_str: str) -> str:
  19. try:
  20. # 解析时间字符串,并转换为 datetime 对象
  21. tweet_datetime = datetime.strptime(date_str, '%a %b %d %H:%M:%S +0000 %Y')
  22. # 将 datetime 对象转换为 MySQL 格式的字符串
  23. return tweet_datetime.strftime('%Y-%m-%d %H:%M:%S')
  24. except ValueError as e:
  25. print(f"Error converting time: {e}")
  26. return None # 或者返回一个默认值
  27. class DatabaseHandler:
  28. def __init__(self):
  29. # 初始化连接池
  30. self.db_config = DB_CONFIG
  31. self.pool = pooling.MySQLConnectionPool(
  32. pool_name="mypool",
  33. pool_size=5, # 设置连接池大小
  34. **DB_CONFIG
  35. )
  36. def get_connection(self):
  37. # 从连接池中获取连接
  38. return self.pool.get_connection()
  39. async def save_tweet(self, tweet, tweet_type, latest_cursor):
  40. conn = self.get_connection()
  41. cursor = conn.cursor()
  42. # 插入推文数据
  43. query = """
  44. INSERT INTO tweets (id, created_at, user_id, text, lang, in_reply_to,
  45. is_quote_status, quote_id, retweeted_tweet_id, possibly_sensitive, quote_count,
  46. reply_count, favorite_count, favorited, view_count, retweet_count, bookmark_count,
  47. bookmarked, place, is_translatable, is_edit_eligible, edits_remaining,tweet_type,next_cursor)
  48. VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,%s,%s)
  49. """
  50. params = (
  51. tweet.id,
  52. convert_to_mysql_datetime(tweet.created_at),
  53. tweet.user.id,
  54. tweet.text,
  55. tweet.lang,
  56. tweet.in_reply_to,
  57. tweet.is_quote_status,
  58. tweet.quote.id if tweet.quote else None,
  59. tweet.retweeted_tweet.id if tweet.retweeted_tweet else None,
  60. tweet.possibly_sensitive,
  61. tweet.quote_count,
  62. tweet.reply_count,
  63. tweet.favorite_count,
  64. tweet.favorited,
  65. tweet.view_count,
  66. tweet.retweet_count,
  67. tweet.bookmark_count,
  68. tweet.bookmarked,
  69. tweet.place if tweet.place else None, # Handle None values
  70. tweet.is_translatable,
  71. tweet.is_edit_eligible,
  72. tweet.edits_remaining,
  73. tweet_type,
  74. latest_cursor
  75. )
  76. try:
  77. cursor.execute(query, params)
  78. conn.commit()
  79. except mysql.connector.Error as err:
  80. print(f"Error inserting tweet: {err}")
  81. conn.rollback() # 如果插入失败,回滚事务
  82. finally:
  83. cursor.close()
  84. conn.close()
  85. # 插入媒体数据
  86. for media in tweet.media:
  87. self.save_media(tweet.id, media)
  88. # 插入hashtags
  89. for hashtag in tweet.hashtags:
  90. self.save_hashtag(tweet.id, hashtag)
  91. # 插入URLs
  92. for url_obj in tweet.urls:
  93. # 提取 expanded_url(完整的 URL)
  94. url = url_obj.get('expanded_url') # 如果你希望保存完整的 URL
  95. if url:
  96. self.save_url(tweet.id, url) # 执行 SQL 插入操作
  97. # 插入回复
  98. if tweet.replies:
  99. for reply in tweet.replies:
  100. self.save_reply(tweet.id, reply.id)
  101. # 插入相关推文(引用或转发)
  102. if tweet.related_tweets:
  103. for related in tweet.related_tweets:
  104. self.save_related_tweet(tweet.id, related.id)
  105. def save_media(self, tweet_id, media):
  106. conn = self.get_connection()
  107. cursor = conn.cursor()
  108. query = """
  109. INSERT INTO media (tweet_id, media_url, media_type)
  110. VALUES (%s, %s, %s)
  111. """
  112. try:
  113. cursor.execute(query, (tweet_id, media.url, media.type))
  114. conn.commit()
  115. except mysql.connector.Error as err:
  116. print(f"Error inserting media: {err}")
  117. conn.rollback() # 如果插入失败,回滚事务
  118. finally:
  119. cursor.close()
  120. conn.close()
  121. def save_hashtag(self, tweet_id, hashtag):
  122. conn = self.get_connection()
  123. cursor = conn.cursor()
  124. query = """
  125. INSERT INTO hashtags (tweet_id, hashtag)
  126. VALUES (%s, %s)
  127. """
  128. try:
  129. cursor.execute(query, (tweet_id, hashtag))
  130. conn.commit()
  131. except mysql.connector.Error as err:
  132. print(f"Error inserting hashtag: {err}")
  133. conn.rollback()
  134. finally:
  135. cursor.close()
  136. conn.close()
  137. def save_url(self, tweet_id, url):
  138. conn = self.get_connection()
  139. cursor = conn.cursor()
  140. query = """
  141. INSERT INTO urls (tweet_id, url)
  142. VALUES (%s, %s)
  143. """
  144. try:
  145. cursor.execute(query, (tweet_id, url))
  146. conn.commit()
  147. except mysql.connector.Error as err:
  148. print(f"Error inserting URL: {err}")
  149. conn.rollback()
  150. finally:
  151. cursor.close()
  152. conn.close()
  153. def save_reply(self, tweet_id, reply_tweet_id):
  154. conn = self.get_connection()
  155. cursor = conn.cursor()
  156. query = """
  157. INSERT INTO replies (tweet_id, reply_tweet_id)
  158. VALUES (%s, %s)
  159. """
  160. try:
  161. cursor.execute(query, (tweet_id, reply_tweet_id))
  162. conn.commit()
  163. except mysql.connector.Error as err:
  164. print(f"Error inserting reply: {err}")
  165. conn.rollback()
  166. finally:
  167. cursor.close()
  168. conn.close()
  169. def save_related_tweet(self, tweet_id, related_tweet_id):
  170. conn = self.get_connection()
  171. cursor = conn.cursor()
  172. query = """
  173. INSERT INTO related_tweets (tweet_id, related_tweet_id)
  174. VALUES (%s, %s)
  175. """
  176. try:
  177. cursor.execute(query, (tweet_id, related_tweet_id))
  178. conn.commit()
  179. except mysql.connector.Error as err:
  180. print(f"Error inserting reply: {err}")
  181. conn.rollback()
  182. finally:
  183. cursor.close()
  184. conn.close()
  185. def get_latest_twitter_id(self, user_id: str, tweet_type: Optional[str] = None) -> Optional[str]:
  186. """获取数据库中某个用户指定类型的最新 twitter_id"""
  187. conn = self.get_connection()
  188. cursor = conn.cursor()
  189. # 构建查询条件
  190. query = "SELECT next_cursor FROM tweets WHERE user_id = %s"
  191. params = [user_id]
  192. if tweet_type:
  193. query += " AND tweet_type = %s"
  194. params.append(tweet_type)
  195. query += " ORDER BY created_at DESC LIMIT 1"
  196. try:
  197. cursor.execute(query, tuple(params))
  198. result = cursor.fetchone()
  199. return result[0] if result else None
  200. except mysql.connector.Error as err:
  201. print(f"Error inserting reply: {err}")
  202. conn.rollback()
  203. finally:
  204. cursor.close()
  205. conn.close()
  206. async def save_user(self, user_data):
  207. conn = self.get_connection()
  208. cursor = conn.cursor()
  209. print(user_data)
  210. # 检查用户是否已存在
  211. check_query = "SELECT COUNT(*) FROM users WHERE id = %s"
  212. cursor.execute(check_query, (user_data['id'],))
  213. result = cursor.fetchone()
  214. if result[0] > 0: # 用户已存在
  215. # 更新用户数据
  216. update_query = """
  217. UPDATE users SET
  218. name = %s,
  219. screen_name = %s,
  220. profile_image_url = %s,
  221. profile_banner_url = %s,
  222. url = %s,
  223. location = %s,
  224. description = %s,
  225. is_blue_verified = %s,
  226. verified = %s,
  227. possibly_sensitive = %s,
  228. can_dm = %s,
  229. can_media_tag = %s,
  230. want_retweets = %s,
  231. default_profile = %s,
  232. default_profile_image = %s,
  233. followers_count = %s,
  234. fast_followers_count = %s,
  235. normal_followers_count = %s,
  236. following_count = %s,
  237. favourites_count = %s,
  238. listed_count = %s,
  239. media_count = %s,
  240. statuses_count = %s,
  241. is_translator = %s,
  242. translator_type = %s,
  243. profile_interstitial_type = %s,
  244. withheld_in_countries = %s
  245. WHERE id = %s
  246. """
  247. cursor.execute(update_query, (
  248. user_data['name'], user_data['screen_name'], user_data['profile_image_url'],
  249. user_data['profile_banner_url'], user_data['url'], user_data['location'],
  250. user_data['description'], user_data['is_blue_verified'], user_data['verified'],
  251. user_data['possibly_sensitive'], user_data['can_dm'], user_data['can_media_tag'],
  252. user_data['want_retweets'], user_data['default_profile'], user_data['default_profile_image'],
  253. user_data['followers_count'], user_data['fast_followers_count'], user_data['normal_followers_count'],
  254. user_data['following_count'], user_data['favourites_count'], user_data['listed_count'],
  255. user_data['media_count'], user_data['statuses_count'], user_data['is_translator'],
  256. user_data['translator_type'], user_data['profile_interstitial_type'],
  257. user_data['withheld_in_countries'], user_data['id']
  258. ))
  259. else: # 用户不存在,执行插入操作
  260. insert_query = """
  261. INSERT INTO users (
  262. id, name, screen_name, profile_image_url, profile_banner_url, url,
  263. location, description, is_blue_verified, verified, possibly_sensitive,
  264. can_dm, can_media_tag, want_retweets, default_profile, default_profile_image,
  265. followers_count, fast_followers_count, normal_followers_count, following_count,
  266. favourites_count, listed_count, media_count, statuses_count, is_translator,
  267. translator_type, profile_interstitial_type, withheld_in_countries
  268. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
  269. """
  270. cursor.execute(insert_query, (
  271. user_data['id'], user_data['name'], user_data['screen_name'], user_data['profile_image_url'],
  272. user_data['profile_banner_url'], user_data['url'], user_data['location'], user_data['description'],
  273. user_data['is_blue_verified'], user_data['verified'], user_data['possibly_sensitive'],
  274. user_data['can_dm'],
  275. user_data['can_media_tag'], user_data['want_retweets'], user_data['default_profile'],
  276. user_data['default_profile_image'], user_data['followers_count'], user_data['fast_followers_count'],
  277. user_data['normal_followers_count'], user_data['following_count'], user_data['favourites_count'],
  278. user_data['listed_count'], user_data['media_count'], user_data['statuses_count'],
  279. user_data['is_translator'],
  280. user_data['translator_type'], user_data['profile_interstitial_type'], user_data['withheld_in_countries']
  281. ))
  282. try:
  283. conn.commit()
  284. except mysql.connector.Error as err:
  285. print(f"Error inserting user: {err}")
  286. conn.rollback()
  287. finally:
  288. cursor.close()
  289. conn.close()
  290. async def get_all_user(self, page: int = 1, page_size: int = 10):
  291. # 连接到数据库
  292. conn = self.get_connection()
  293. cursor = conn.cursor()
  294. # 计算偏移量
  295. offset = (page - 1) * page_size
  296. # 执行查询获取所有用户的 id
  297. query = "SELECT * FROM users LIMIT %s OFFSET %s"
  298. cursor.execute(query, (page_size, offset))
  299. # 获取所有结果,提取 id 列
  300. user_list = cursor.fetchall()
  301. # 获取总记录数
  302. count_query = "SELECT COUNT(*) FROM users"
  303. cursor.execute(count_query)
  304. total_records = cursor.fetchone()[0]
  305. # 关闭游标和连接
  306. cursor.close()
  307. conn.close()
  308. users = [
  309. User(
  310. id=row[0],
  311. name=row[1],
  312. screen_name=row[2],
  313. profile_image_url=row[3],
  314. profile_banner_url=row[4],
  315. url=row[5],
  316. location=row[6],
  317. description=row[7],
  318. is_blue_verified=row[8],
  319. verified=row[9],
  320. possibly_sensitive=row[10],
  321. can_dm=row[11],
  322. can_media_tag=row[12],
  323. want_retweets=row[13],
  324. default_profile=row[14],
  325. default_profile_image=row[15],
  326. followers_count=row[16],
  327. fast_followers_count=row[17],
  328. normal_followers_count=row[18],
  329. following_count=row[19],
  330. favourites_count=row[20],
  331. listed_count=row[21],
  332. media_count=row[22],
  333. statuses_count=row[23],
  334. is_translator=row[24],
  335. translator_type=row[25],
  336. profile_interstitial_type=row[26],
  337. withheld_in_countries=row[27]
  338. ) for row in user_list
  339. ]
  340. # 返回所有用户 id 的列表
  341. return users, total_records
  342. async def get_all_twitter(self, page: int = 1, page_size: int = 10):
  343. """
  344. 获取所有 tweets 并支持分页。
  345. :param page: 当前页码,默认为 1
  346. :param page_size: 每页显示的记录数,默认为 10
  347. :return: tweets 列表
  348. """
  349. # 计算偏移量
  350. offset = (page - 1) * page_size
  351. # 连接到数据库
  352. conn = self.get_connection()
  353. cursor = conn.cursor()
  354. # 执行分页查询获取 tweets
  355. query = "SELECT * FROM tweets LIMIT %s OFFSET %s"
  356. cursor.execute(query, (page_size, offset))
  357. # 获取所有结果
  358. tweets_list = cursor.fetchall()
  359. # 返回查询结果
  360. tweets = [TweetModel(**tweet) for tweet in tweets_list]
  361. # 获取总记录数
  362. count_query = "SELECT COUNT(*) FROM tweets"
  363. cursor.execute(count_query)
  364. total_records = cursor.fetchone()[0]
  365. # 关闭游标和连接
  366. cursor.close()
  367. conn.close()
  368. # 返回所有 tweets 的列表
  369. return tweets, total_records
  370. async def get_all_user_ids(self):
  371. # 连接到数据库
  372. conn = mysql.connector.connect(**self.db_config)
  373. cursor = conn.cursor()
  374. # 执行查询获取所有用户的 id
  375. query = "SELECT id FROM users"
  376. cursor.execute(query)
  377. # 获取所有结果,提取 id 列
  378. user_ids = [row[0] for row in cursor.fetchall()]
  379. # 关闭游标和连接
  380. cursor.close()
  381. conn.close()
  382. # 返回所有用户 id 的列表
  383. return user_ids