Initial database setup.

This commit is contained in:
2025-07-21 12:35:18 +02:00
parent b617016b61
commit e80978681a
3 changed files with 200 additions and 53 deletions

136
database.py Normal file
View File

@@ -0,0 +1,136 @@
# database.py
import sqlite3
import time
DB_FILE = "reddit_stocks.db"
def get_db_connection():
    """Open the application's SQLite database.

    Returns:
        sqlite3.Connection: a connection whose row_factory is sqlite3.Row,
        so query results support access by column name as well as by index.
    """
    connection = sqlite3.connect(DB_FILE)
    connection.row_factory = sqlite3.Row
    return connection
def initialize_db():
    """Create the schema (tickers, subreddits, mentions) if it does not exist.

    Safe to call repeatedly: every statement uses CREATE TABLE IF NOT EXISTS.
    """
    conn = get_db_connection()
    cursor = conn.cursor()

    schema_statements = (
        # One row per stock symbol, with cached market-cap data and the
        # epoch timestamp of its last refresh.
        """
        CREATE TABLE IF NOT EXISTS tickers (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            symbol TEXT NOT NULL UNIQUE,
            market_cap INTEGER,
            last_updated INTEGER
        )
        """,
        # One row per scanned subreddit.
        """
        CREATE TABLE IF NOT EXISTS subreddits (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL UNIQUE
        )
        """,
        # One row per (ticker, post) sighting; the UNIQUE constraint keeps
        # re-scans from double-counting the same post.
        """
        CREATE TABLE IF NOT EXISTS mentions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            ticker_id INTEGER,
            subreddit_id INTEGER,
            post_id TEXT NOT NULL,
            mention_timestamp INTEGER NOT NULL,
            sentiment_score REAL,
            FOREIGN KEY (ticker_id) REFERENCES tickers (id),
            FOREIGN KEY (subreddit_id) REFERENCES subreddits (id),
            UNIQUE(ticker_id, post_id)
        )
        """,
    )
    for statement in schema_statements:
        cursor.execute(statement)

    conn.commit()
    conn.close()
    print("Database initialized successfully.")
def get_or_create_entity(conn, table_name, column_name, value):
    """Return the id of the row in *table_name* whose *column_name* equals
    *value*, inserting the row first if it does not exist.

    Args:
        conn: Open sqlite3 connection. Its row_factory must be sqlite3.Row
            (as configured by get_db_connection) so rows support key access.
        table_name: Table to query. SQL identifiers cannot be bound as
            parameters, so this is interpolated into the statement — it must
            come from trusted code, never from user input.
        column_name: UNIQUE column holding *value* (same trust caveat).
        value: The value to find or insert (bound safely as a parameter).

    Returns:
        int: id of the existing or newly created row.
    """
    cursor = conn.cursor()
    # INSERT OR IGNORE makes creation idempotent: a repeated or concurrent
    # call cannot raise IntegrityError on the UNIQUE column, and there is no
    # SELECT-then-INSERT race window between two separate statements.
    cursor.execute(
        f"INSERT OR IGNORE INTO {table_name} ({column_name}) VALUES (?)",
        (value,),
    )
    conn.commit()
    # A single SELECT now always finds the row, whether it pre-existed or was
    # just inserted (cursor.lastrowid is not reliable after OR IGNORE skips).
    cursor.execute(
        f"SELECT id FROM {table_name} WHERE {column_name} = ?",
        (value,),
    )
    return cursor.fetchone()['id']
def add_mention(conn, ticker_id, subreddit_id, post_id, timestamp):
    """Record one ticker mention; a duplicate (ticker_id, post_id) pair is a no-op.

    Args:
        conn: Open sqlite3 connection.
        ticker_id: Row id in the tickers table.
        subreddit_id: Row id in the subreddits table.
        post_id: Reddit post identifier.
        timestamp: Epoch seconds of the mention.
    """
    insert_sql = (
        "INSERT INTO mentions (ticker_id, subreddit_id, post_id, mention_timestamp)"
        " VALUES (?, ?, ?, ?)"
    )
    try:
        conn.cursor().execute(insert_sql, (ticker_id, subreddit_id, post_id, timestamp))
        conn.commit()
    except sqlite3.IntegrityError:
        # The UNIQUE(ticker_id, post_id) constraint fired: already stored.
        pass
def update_ticker_market_cap(conn, ticker_id, market_cap):
    """Store a fresh market cap for *ticker_id*, stamped with the current time.

    Args:
        conn: Open sqlite3 connection.
        ticker_id: Row id in the tickers table.
        market_cap: New market-cap value (may be None to clear it).
    """
    now = int(time.time())
    update_sql = "UPDATE tickers SET market_cap = ?, last_updated = ? WHERE id = ?"
    conn.cursor().execute(update_sql, (market_cap, now, ticker_id))
    conn.commit()
def get_ticker_info(conn, ticker_id):
    """Fetch the full tickers row for *ticker_id*.

    Returns:
        sqlite3.Row | None: the matching row, or None when no such id exists.
    """
    row = conn.cursor().execute(
        "SELECT * FROM tickers WHERE id = ?", (ticker_id,)
    ).fetchone()
    return row
def generate_summary_report():
    """Query the DB and print the top 20 tickers by mention count.

    Each line shows symbol, mention count, and a human-readable market cap
    (e.g. $1.23T, $45.60B, $123.40M, or N/A when no cap is stored).

    Fix: the connection is now closed in a ``finally`` block, so it no longer
    leaks when a query or print raises mid-report.
    """
    print("\n--- Summary Report ---")
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        query = """
            SELECT
                t.symbol,
                t.market_cap,
                COUNT(m.id) as mention_count
            FROM mentions m
            JOIN tickers t ON m.ticker_id = t.id
            GROUP BY t.symbol, t.market_cap
            ORDER BY mention_count DESC
            LIMIT 20;
        """
        results = cursor.execute(query).fetchall()
        print(f"{'Ticker':<10} | {'Mentions':<10} | {'Market Cap':<20}")
        print("-" * 45)
        for row in results:
            cap_str = _format_market_cap(row['market_cap'])
            print(f"{row['symbol']:<10} | {row['mention_count']:<10} | {cap_str:<20}")
    finally:
        # Always release the connection, even if a query fails mid-report.
        conn.close()
def _format_market_cap(market_cap):
    """Render a raw market-cap integer as $X.XXT/B/M, or 'N/A' when missing.

    A falsy value (None or 0) is treated as "no data", matching the original
    report's behavior.
    """
    if not market_cap:
        return "N/A"
    if market_cap >= 1e12:
        return f"${market_cap/1e12:.2f}T"
    if market_cap >= 1e9:
        return f"${market_cap/1e9:.2f}B"
    if market_cap >= 1e6:
        return f"${market_cap/1e6:.2f}M"
    return f"${market_cap:,}"

101
main.py
View File

@@ -3,19 +3,22 @@
import argparse import argparse
import json import json
import os import os
from collections import Counter import time
import praw import praw
import yfinance as yf import yfinance as yf
from dotenv import load_dotenv from dotenv import load_dotenv
import database
from ticker_extractor import extract_tickers from ticker_extractor import extract_tickers
# Load environment variables from .env file # Load environment variables from .env file
load_dotenv() load_dotenv()
# How old (in seconds) market cap data can be before we refresh it. 24 hours = 86400 seconds.
MARKET_CAP_REFRESH_INTERVAL = 86400
def load_subreddits(filepath): def load_subreddits(filepath):
"""Loads a list of subreddits from a JSON file.""" # (This function is unchanged)
try: try:
with open(filepath, 'r') as f: with open(filepath, 'r') as f:
data = json.load(f) data = json.load(f)
@@ -28,21 +31,17 @@ def load_subreddits(filepath):
return None return None
def get_market_cap(ticker_symbol): def get_market_cap(ticker_symbol):
"""Fetches the market capitalization for a given stock ticker.""" """Fetches the market capitalization for a given stock ticker from yfinance."""
try: try:
ticker = yf.Ticker(ticker_symbol) ticker = yf.Ticker(ticker_symbol)
market_cap = ticker.info.get('marketCap') # .info can be slow; .fast_info is a lighter alternative
if market_cap: market_cap = ticker.fast_info.get('marketCap')
# Formatting for better readability return market_cap if market_cap else None
return f"${market_cap:,}" except Exception:
return None
return "N/A"
except Exception as e:
# yfinance can sometimes fail for various reasons (e.g., invalid ticker)
return "N/A"
def get_reddit_instance(): def get_reddit_instance():
"""Initializes and returns a PRAW Reddit instance.""" # (This function is unchanged)
client_id = os.getenv("REDDIT_CLIENT_ID") client_id = os.getenv("REDDIT_CLIENT_ID")
client_secret = os.getenv("REDDIT_CLIENT_SECRET") client_secret = os.getenv("REDDIT_CLIENT_SECRET")
user_agent = os.getenv("REDDIT_USER_AGENT") user_agent = os.getenv("REDDIT_USER_AGENT")
@@ -50,41 +49,54 @@ def get_reddit_instance():
if not all([client_id, client_secret, user_agent]): if not all([client_id, client_secret, user_agent]):
print("Error: Reddit API credentials not found in .env file.") print("Error: Reddit API credentials not found in .env file.")
return None return None
return praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)
return praw.Reddit(
client_id=client_id,
client_secret=client_secret,
user_agent=user_agent
)
def scan_subreddits(reddit, subreddits_list, post_limit=25): def scan_subreddits(reddit, subreddits_list, post_limit=25):
"""Scans subreddits for stock tickers and returns a count of each.""" """Scans subreddits, stores mentions, and updates market caps in the database."""
all_tickers = Counter() conn = database.get_db_connection()
print(f"\nScanning {len(subreddits_list)} subreddits for top {post_limit} posts...") print(f"\nScanning {len(subreddits_list)} subreddits for top {post_limit} posts...")
for subreddit_name in subreddits_list: for subreddit_name in subreddits_list:
try: try:
subreddit_id = database.get_or_create_entity(conn, 'subreddits', 'name', subreddit_name)
subreddit = reddit.subreddit(subreddit_name) subreddit = reddit.subreddit(subreddit_name)
print(f"r/{subreddit_name}...") print(f"Scanning r/{subreddit_name}...")
# Fetch hot posts from the subreddit
for submission in subreddit.hot(limit=post_limit): for submission in subreddit.hot(limit=post_limit):
# Combine title and selftext for analysis
full_text = submission.title + " " + submission.selftext full_text = submission.title + " " + submission.selftext
# Extract tickers from the combined text
tickers_in_post = extract_tickers(full_text) tickers_in_post = extract_tickers(full_text)
all_tickers.update(tickers_in_post)
# Future work: also scan comments for ticker_symbol in set(tickers_in_post):
# submission.comments.replace_more(limit=0) # Expand all comment trees ticker_id = database.get_or_create_entity(conn, 'tickers', 'symbol', ticker_symbol)
# for comment in submission.comments.list():
# tickers_in_comment = extract_tickers(comment.body) database.add_mention(
# all_tickers.update(tickers_in_comment) conn,
ticker_id=ticker_id,
subreddit_id=subreddit_id,
post_id=submission.id,
timestamp=int(submission.created_utc)
)
# --- Check if market cap needs updating ---
ticker_info = database.get_ticker_info(conn, ticker_id)
current_time = int(time.time())
if not ticker_info['last_updated'] or (current_time - ticker_info['last_updated'] > MARKET_CAP_REFRESH_INTERVAL):
print(f" -> Fetching market cap for {ticker_symbol}...")
market_cap = get_market_cap(ticker_symbol)
if market_cap:
database.update_ticker_market_cap(conn, ticker_id, market_cap)
else:
# If fetch fails, still update the timestamp so we don't try again for 24 hours
database.update_ticker_market_cap(conn, ticker_id, ticker_info['market_cap']) # Keep old value
except Exception as e: except Exception as e:
print(f"Could not scan r/{subreddit_name}. Error: {e}") print(f"Could not scan r/{subreddit_name}. Error: {e}")
return all_tickers conn.close()
print("\n--- Scan Complete ---")
def main(): def main():
"""Main function to run the Reddit stock analysis tool.""" """Main function to run the Reddit stock analysis tool."""
@@ -92,28 +104,21 @@ def main():
parser.add_argument("config_file", help="Path to the JSON file containing subreddits.") parser.add_argument("config_file", help="Path to the JSON file containing subreddits.")
args = parser.parse_args() args = parser.parse_args()
# --- Part 1: Load Configuration & Initialize Reddit --- # --- Part 1: Initialize ---
database.initialize_db()
subreddits = load_subreddits(args.config_file) subreddits = load_subreddits(args.config_file)
if not subreddits: if not subreddits: return
return
reddit = get_reddit_instance() reddit = get_reddit_instance()
if not reddit: if not reddit: return
return
# --- Part 2: Scan Reddit for Tickers --- # --- Part 2: Scan and Store ---
ticker_counts = scan_subreddits(reddit, subreddits) scan_subreddits(reddit, subreddits)
if not ticker_counts:
print("No tickers found.")
return
print("\n--- Scan Complete ---") # --- Part 3: Generate and Display Report ---
print("Top 15 mentioned tickers:") database.generate_summary_report()
# --- Part 3: Display Results ---
# We will enrich this data with market cap and sentiment in the next steps
for ticker, count in ticker_counts.most_common(15):
print(f"{ticker}: {count} mentions")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -11,11 +11,17 @@ COMMON_WORDS_BLACKLIST = {
"WAY", "WHO", "WHY", "BIG", "BUY", "SELL", "HOLD", "BE", "GO", "WAY", "WHO", "WHY", "BIG", "BUY", "SELL", "HOLD", "BE", "GO",
"ON", "AT", "IN", "IS", "IT", "OF", "OR", "TO", "WE", "UP", "ON", "AT", "IN", "IS", "IT", "OF", "OR", "TO", "WE", "UP",
"OUT", "SO", "RH", "SEC", "IRS", "USA", "UK", "EU", "OUT", "SO", "RH", "SEC", "IRS", "USA", "UK", "EU",
"AI", "ML", "AR", "VR", "NFT", "DAO", "WEB3", "ETH", "BTC", "AI", "ML", "AR", "VR", "NFT", "DAO", "WEB3", "ETH", "BTC", "DOGE",
"USD", "EUR", "GBP", "JPY", "CNY", "INR", "AUD", "CAD", "CHF", "USD", "EUR", "GBP", "JPY", "CNY", "INR", "AUD", "CAD", "CHF",
"RUB", "ZAR", "BRL", "MXN", "HKD", "SGD", "NZD", "RSD", "RUB", "ZAR", "BRL", "MXN", "HKD", "SGD", "NZD", "RSD",
"JPY", "KRW", "SEK", "NOK", "DKK", "PLN", "CZK", "HUF", "TRY", "JPY", "KRW", "SEK", "NOK", "DKK", "PLN", "CZK", "HUF", "TRY",
"US", "IRA", "FDA", "SEC", "FBI", "CIA", "NSA", "NATO", "US", "IRA", "FDA", "SEC", "FBI", "CIA", "NSA", "NATO", "FINRA",
"NASDAQ", "NYSE", "AMEX", "FTSE", "DAX", "WSB", "SPX", "DJIA",
"EDGAR", "GDP", "CPI", "PPI", "PMI", "ISM", "FOMC", "ECB", "BOE",
"BOJ", "RBA", "RBNZ", "BIS", "NFA", "P", "VOO", "CTB", "DR",
"ETF", "EV", "ESG", "REIT", "SPAC", "IPO", "M&A", "LBO",
"Q1", "Q2", "Q3", "Q4", "FY", "FAQ", "ROI", "ROE", "EPS", "P/E", "PEG",
"FRG", "FXAIX", "FXIAX", "FZROX"
} }
def extract_tickers(text): def extract_tickers(text):