Remove duplicate definitions, add functionality to fetch market data for a single stock, and print default values for command line options.

This commit is contained in:
2025-07-23 22:43:46 +02:00
parent 07c1fd3841
commit c9e754c9c9
2 changed files with 67 additions and 88 deletions

View File

@@ -9,72 +9,6 @@ from datetime import datetime, timedelta, timezone
DB_FILE = "reddit_stocks.db"
MARKET_CAP_REFRESH_INTERVAL = 86400
def get_db_connection():
"""Establishes a connection to the SQLite database."""
conn = sqlite3.connect(DB_FILE)
conn.row_factory = sqlite3.Row
return conn
def initialize_db():
"""
Initializes the database and creates the necessary tables if they don't exist.
"""
conn = get_db_connection()
cursor = conn.cursor()
# --- Create tickers table ---
cursor.execute("""
CREATE TABLE IF NOT EXISTS tickers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT NOT NULL UNIQUE,
market_cap INTEGER,
closing_price REAL,
last_updated INTEGER
)
""")
# --- Create subreddits table ---
cursor.execute("""
CREATE TABLE IF NOT EXISTS subreddits (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE
)
""")
# --- Create mentions table ---
cursor.execute("""
CREATE TABLE IF NOT EXISTS mentions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ticker_id INTEGER,
subreddit_id INTEGER,
post_id TEXT NOT NULL,
mention_type TEXT NOT NULL,
mention_sentiment REAL, -- Renamed from sentiment_score for clarity
post_avg_sentiment REAL, -- NEW: Stores the avg sentiment of the whole post
mention_timestamp INTEGER NOT NULL,
FOREIGN KEY (ticker_id) REFERENCES tickers (id),
FOREIGN KEY (subreddit_id) REFERENCES subreddits (id)
)
""")
cursor.execute("""
CREATE TABLE IF NOT EXISTS posts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
post_id TEXT NOT NULL UNIQUE,
title TEXT NOT NULL,
post_url TEXT,
subreddit_id INTEGER,
post_timestamp INTEGER,
comment_count INTEGER,
avg_comment_sentiment REAL,
FOREIGN KEY (subreddit_id) REFERENCES subreddits (id)
)
""")
conn.commit()
conn.close()
log.info("Database initialized successfully.")
def clean_stale_tickers():
"""
Removes tickers and their associated mentions from the database
@@ -381,4 +315,16 @@ def get_all_tickers():
conn = get_db_connection()
results = conn.execute("SELECT id, symbol FROM tickers;").fetchall()
conn.close()
return results
return results
def get_ticker_by_symbol(symbol):
"""
Retrieves a single ticker's ID and symbol from the database.
The search is case-insensitive. Returns a Row object or None if not found.
"""
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute("SELECT id, symbol FROM tickers WHERE LOWER(symbol) = LOWER(?)", (symbol,))
result = cursor.fetchone()
conn.close()
return result