"""MySQL persistence layer for scraped Twitter users and tweets."""
import os
from contextlib import contextmanager
from datetime import datetime
from typing import Optional

import mysql.connector
from mysql.connector import pooling

from server.module.TweetModel import TweetModel
from server.module.UserModel import User

# MySQL connection settings.
# SECURITY NOTE(review): these credentials are committed to source control.
# Each value can be overridden through the environment; the literals are kept
# only as defaults so existing deployments keep working. They should be
# rotated and removed from the repository.
DB_CONFIG = {
    'host': os.environ.get('TWITTER_DB_HOST', '117.78.31.244'),
    'user': os.environ.get('TWITTER_DB_USER', 'root'),
    'password': os.environ.get('TWITTER_DB_PASSWORD', 'zh123456'),
    'database': os.environ.get('TWITTER_DB_NAME', 'twitter_spider'),
    'charset': 'utf8mb4',
}

# Columns of the `users` table, in the order used by every user query below.
# `id` comes first; the remaining names double as the keys read from the
# `user_data` mapping and as the keyword arguments passed to `User`.
_USER_COLUMNS = (
    'id',
    'name',
    'screen_name',
    'profile_image_url',
    'profile_banner_url',
    'url',
    'location',
    'description',
    'is_blue_verified',
    'verified',
    'possibly_sensitive',
    'can_dm',
    'can_media_tag',
    'want_retweets',
    'default_profile',
    'default_profile_image',
    'followers_count',
    'fast_followers_count',
    'normal_followers_count',
    'following_count',
    'favourites_count',
    'listed_count',
    'media_count',
    'statuses_count',
    'is_translator',
    'translator_type',
    'profile_interstitial_type',
    'withheld_in_countries',
)
_USER_DATA_COLUMNS = _USER_COLUMNS[1:]

# The SQL text is assembled from the constant tuple above only (never from
# caller input); all values are still passed as bound parameters.
_USER_UPDATE_SQL = "UPDATE users SET {} WHERE id = %s".format(
    ", ".join(f"{column} = %s" for column in _USER_DATA_COLUMNS)
)
_USER_INSERT_SQL = "INSERT INTO users ({}) VALUES ({})".format(
    ", ".join(_USER_COLUMNS),
    ", ".join(["%s"] * len(_USER_COLUMNS)),
)
_USER_PAGE_SQL = "SELECT {} FROM users ORDER BY id LIMIT %s OFFSET %s".format(
    ", ".join(_USER_COLUMNS)
)

_TWEET_INSERT_SQL = """
    INSERT INTO tweets (id, created_at, user_id, text, lang, in_reply_to,
                        is_quote_status, quote_id, retweeted_tweet_id,
                        possibly_sensitive, quote_count, reply_count,
                        favorite_count, favorited, view_count, retweet_count,
                        bookmark_count, bookmarked, place, is_translatable,
                        is_edit_eligible, edits_remaining, tweet_type,
                        next_cursor)
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
            %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""


def convert_to_mysql_datetime(date_str: str) -> Optional[str]:
    """Convert a tweet timestamp to a MySQL DATETIME string.

    :param date_str: timestamp such as ``'Thu Feb 20 00:38:20 +0000 2025'``.
        Only the literal ``+0000`` offset is accepted.
    :return: ``'YYYY-MM-DD HH:MM:SS'``, or ``None`` (after printing the
        parse error) when ``date_str`` does not match the expected format.
    """
    try:
        tweet_datetime = datetime.strptime(date_str, '%a %b %d %H:%M:%S +0000 %Y')
    except ValueError as e:
        print(f"Error converting time: {e}")
        return None
    return tweet_datetime.strftime('%Y-%m-%d %H:%M:%S')


class DatabaseHandler:
    """Read/write access to the ``twitter_spider`` tables through a pool.

    NOTE(review): the ``async`` methods contain no ``await`` and call the
    synchronous connector directly, so they block while the query runs.
    """

    def __init__(self):
        """Create the connection pool from ``DB_CONFIG``."""
        self.db_config = DB_CONFIG
        self.pool = pooling.MySQLConnectionPool(
            pool_name="mypool",
            pool_size=5,
            **DB_CONFIG
        )

    def get_connection(self):
        """Return a connection taken from the pool."""
        return self.pool.get_connection()

    @contextmanager
    def _cursor(self, **cursor_kwargs):
        """Yield ``(connection, cursor)`` and always release both.

        ``cursor_kwargs`` are forwarded to ``connection.cursor()``. Closing a
        pooled connection hands it back to the pool, so doing it in
        ``finally`` keeps failures from draining the pool.
        """
        conn = self.get_connection()
        try:
            cursor = conn.cursor(**cursor_kwargs)
            try:
                yield conn, cursor
            finally:
                cursor.close()
        finally:
            conn.close()

    def _execute_write(self, query, params, what):
        """Run one write statement and commit it.

        A ``mysql.connector.Error`` is printed as ``Error inserting <what>``
        and the transaction is rolled back; the error is not re-raised.
        """
        with self._cursor() as (conn, cursor):
            try:
                cursor.execute(query, params)
                conn.commit()
            except mysql.connector.Error as err:
                print(f"Error inserting {what}: {err}")
                conn.rollback()

    async def save_tweet(self, tweet, tweet_type, latest_cursor):
        """Insert a tweet row followed by its media, hashtags, URLs, replies
        and related tweets.

        :param tweet: tweet object exposing the attributes read below.
        :param tweet_type: value stored in the ``tweet_type`` column.
        :param latest_cursor: value stored in the ``next_cursor`` column.

        Each insert runs in its own transaction. The child rows are written
        even when the tweet insert itself failed, as before.
        """
        params = (
            tweet.id,
            convert_to_mysql_datetime(tweet.created_at),
            tweet.user.id,
            tweet.text,
            tweet.lang,
            tweet.in_reply_to,
            tweet.is_quote_status,
            tweet.quote.id if tweet.quote else None,
            tweet.retweeted_tweet.id if tweet.retweeted_tweet else None,
            tweet.possibly_sensitive,
            tweet.quote_count,
            tweet.reply_count,
            tweet.favorite_count,
            tweet.favorited,
            tweet.view_count,
            tweet.retweet_count,
            tweet.bookmark_count,
            tweet.bookmarked,
            tweet.place if tweet.place else None,  # store falsy places as NULL
            tweet.is_translatable,
            tweet.is_edit_eligible,
            tweet.edits_remaining,
            tweet_type,
            latest_cursor,
        )
        self._execute_write(_TWEET_INSERT_SQL, params, "tweet")

        # `or ()` lets a missing (None) collection be skipped instead of
        # raising, matching the existing guards on replies/related_tweets.
        for media in tweet.media or ():
            self.save_media(tweet.id, media)

        for hashtag in tweet.hashtags or ():
            self.save_hashtag(tweet.id, hashtag)

        for url_obj in tweet.urls or ():
            # Only the expanded (full) URL is stored; entries without one
            # are skipped.
            url = url_obj.get('expanded_url')
            if url:
                self.save_url(tweet.id, url)

        for reply in tweet.replies or ():
            self.save_reply(tweet.id, reply.id)

        # Related tweets (quotes or retweets).
        for related in tweet.related_tweets or ():
            self.save_related_tweet(tweet.id, related.id)

    def save_media(self, tweet_id, media):
        """Insert one ``media`` row using ``media.url`` and ``media.type``."""
        self._execute_write(
            "INSERT INTO media (tweet_id, media_url, media_type) VALUES (%s, %s, %s)",
            (tweet_id, media.url, media.type),
            "media",
        )

    def save_hashtag(self, tweet_id, hashtag):
        """Insert one ``hashtags`` row for ``tweet_id``."""
        self._execute_write(
            "INSERT INTO hashtags (tweet_id, hashtag) VALUES (%s, %s)",
            (tweet_id, hashtag),
            "hashtag",
        )

    def save_url(self, tweet_id, url):
        """Insert one ``urls`` row for ``tweet_id``."""
        self._execute_write(
            "INSERT INTO urls (tweet_id, url) VALUES (%s, %s)",
            (tweet_id, url),
            "URL",
        )

    def save_reply(self, tweet_id, reply_tweet_id):
        """Insert one ``replies`` row linking ``tweet_id`` to a reply."""
        self._execute_write(
            "INSERT INTO replies (tweet_id, reply_tweet_id) VALUES (%s, %s)",
            (tweet_id, reply_tweet_id),
            "reply",
        )

    def save_related_tweet(self, tweet_id, related_tweet_id):
        """Insert one ``related_tweets`` row linking two tweets."""
        self._execute_write(
            "INSERT INTO related_tweets (tweet_id, related_tweet_id) VALUES (%s, %s)",
            (tweet_id, related_tweet_id),
            "related tweet",
        )

    def get_latest_twitter_id(self, user_id: str, tweet_type: Optional[str] = None) -> Optional[str]:
        """Return the ``next_cursor`` of a user's most recent tweet.

        "Most recent" means the highest ``created_at``.

        :param user_id: value matched against ``tweets.user_id``.
        :param tweet_type: when truthy, also filter on ``tweets.tweet_type``.
        :return: the stored ``next_cursor``, or ``None`` when no row matches
            or the query fails (the error is printed).
        """
        query = "SELECT next_cursor FROM tweets WHERE user_id = %s"
        params = [user_id]
        if tweet_type:
            query += " AND tweet_type = %s"
            params.append(tweet_type)
        query += " ORDER BY created_at DESC LIMIT 1"

        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute(query, tuple(params))
                row = cursor.fetchone()
            except mysql.connector.Error as err:
                print(f"Error fetching latest cursor: {err}")
                return None
        return row[0] if row else None

    async def save_user(self, user_data):
        """Insert a user, or update the row when the id already exists.

        :param user_data: mapping with an ``'id'`` key plus one key per
            column in ``_USER_DATA_COLUMNS``; a missing key raises
            ``KeyError`` before any connection is taken.

        A ``mysql.connector.Error`` from any statement is printed and the
        transaction rolled back, like the other ``save_*`` methods.
        """
        print(user_data)
        user_id = user_data['id']
        values = tuple(user_data[column] for column in _USER_DATA_COLUMNS)

        with self._cursor() as (conn, cursor):
            try:
                cursor.execute("SELECT COUNT(*) FROM users WHERE id = %s", (user_id,))
                already_stored = cursor.fetchone()[0] > 0
                if already_stored:
                    cursor.execute(_USER_UPDATE_SQL, values + (user_id,))
                else:
                    cursor.execute(_USER_INSERT_SQL, (user_id,) + values)
                conn.commit()
            except mysql.connector.Error as err:
                print(f"Error inserting user: {err}")
                conn.rollback()

    async def get_all_user(self, page: int = 1, page_size: int = 10):
        """Return one page of users plus the total user count.

        :param page: 1-based page number; values below 1 are treated as 1.
        :param page_size: maximum number of rows to return.
        :return: ``(users, total_records)`` where ``users`` is a list of
            ``User`` objects.
        """
        offset = (max(page, 1) - 1) * page_size

        with self._cursor() as (_conn, cursor):
            # Columns are named explicitly so the mapping below does not
            # depend on the table's physical column order; ORDER BY keeps
            # the paging stable between calls.
            cursor.execute(_USER_PAGE_SQL, (page_size, offset))
            rows = cursor.fetchall()

            cursor.execute("SELECT COUNT(*) FROM users")
            total_records = cursor.fetchone()[0]

        users = [User(**dict(zip(_USER_COLUMNS, row))) for row in rows]
        return users, total_records

    async def get_all_twitter(self, page: int = 1, page_size: int = 10):
        """Return one page of tweets plus the total tweet count.

        :param page: 1-based page number; values below 1 are treated as 1.
        :param page_size: maximum number of rows to return.
        :return: ``(tweets, total_records)`` where ``tweets`` is a list of
            ``TweetModel`` objects built from the row's columns.
        """
        offset = (max(page, 1) - 1) * page_size

        # A dictionary cursor is required: TweetModel(**row) needs a mapping
        # of column name to value, not a tuple.
        with self._cursor(dictionary=True) as (_conn, cursor):
            cursor.execute(
                "SELECT * FROM tweets ORDER BY id LIMIT %s OFFSET %s",
                (page_size, offset),
            )
            rows = cursor.fetchall()

            cursor.execute("SELECT COUNT(*) AS total FROM tweets")
            total_records = cursor.fetchone()['total']

        # presumably TweetModel accepts every tweets column as a keyword
        # argument - TODO confirm against the model definition.
        tweets = [TweetModel(**row) for row in rows]
        return tweets, total_records

    async def get_all_user_ids(self):
        """Return the ``id`` of every row in ``users`` as a list."""
        with self._cursor() as (_conn, cursor):
            cursor.execute("SELECT id FROM users")
            return [row[0] for row in cursor.fetchall()]