# main.py
import argparse
import json
import os
from collections import Counter

import praw
import yfinance as yf
from dotenv import load_dotenv

from ticker_extractor import extract_tickers

# Load environment variables from .env file
load_dotenv()


def load_subreddits(filepath):
    """Load the list of subreddit names from a JSON config file.

    Args:
        filepath: Path to a JSON file with a top-level "subreddits" array.

    Returns:
        The list of subreddit names ([] if the key is absent), or None when
        the file is missing or contains malformed JSON (an error message is
        printed in those cases).
    """
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
            return data.get("subreddits", [])
    except FileNotFoundError:
        print(f"Error: The file '{filepath}' was not found.")
        return None
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from '{filepath}'.")
        return None


def get_market_cap(ticker_symbol):
    """Fetch the market capitalization for a given stock ticker.

    Args:
        ticker_symbol: The stock symbol to look up (e.g. "AAPL").

    Returns:
        A comma-grouped dollar string (e.g. "$1,234,567"), or "N/A" when the
        value is missing or the lookup fails.
    """
    try:
        ticker = yf.Ticker(ticker_symbol)
        market_cap = ticker.info.get('marketCap')
        if market_cap:
            # Formatting for better readability
            return f"${market_cap:,}"
        return "N/A"
    except Exception:
        # yfinance can sometimes fail for various reasons (e.g., invalid
        # ticker); treat any failure as "no data" rather than crashing.
        return "N/A"


def get_reddit_instance():
    """Initialize and return a PRAW Reddit instance.

    Reads REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET / REDDIT_USER_AGENT from
    the environment (populated by load_dotenv()). Returns None and prints an
    error if any credential is missing.
    """
    client_id = os.getenv("REDDIT_CLIENT_ID")
    client_secret = os.getenv("REDDIT_CLIENT_SECRET")
    user_agent = os.getenv("REDDIT_USER_AGENT")

    if not all([client_id, client_secret, user_agent]):
        print("Error: Reddit API credentials not found in .env file.")
        return None

    return praw.Reddit(
        client_id=client_id,
        client_secret=client_secret,
        user_agent=user_agent
    )


def scan_subreddits(reddit, subreddits_list, post_limit=25):
    """Scan subreddits for stock tickers and return a count of each.

    Args:
        reddit: An authenticated praw.Reddit instance.
        subreddits_list: Iterable of subreddit names (without the "r/").
        post_limit: How many hot posts to fetch per subreddit.

    Returns:
        A collections.Counter mapping ticker symbol -> mention count.
        Subreddits that fail to scan are reported and skipped.
    """
    all_tickers = Counter()
    print(f"\nScanning {len(subreddits_list)} subreddits for top {post_limit} posts...")

    for subreddit_name in subreddits_list:
        try:
            subreddit = reddit.subreddit(subreddit_name)
            print(f"r/{subreddit_name}...")
            # Fetch hot posts from the subreddit
            for submission in subreddit.hot(limit=post_limit):
                # Combine title and selftext for analysis
                full_text = submission.title + " " + submission.selftext
                # Extract tickers from the combined text
                tickers_in_post = extract_tickers(full_text)
                all_tickers.update(tickers_in_post)

                # Future work: also scan comments
                # submission.comments.replace_more(limit=0)  # Expand all comment trees
                # for comment in submission.comments.list():
                #     tickers_in_comment = extract_tickers(comment.body)
                #     all_tickers.update(tickers_in_comment)
        except Exception as e:
            # Best-effort: a private/banned/misspelled subreddit should not
            # abort the whole scan.
            print(f"Could not scan r/{subreddit_name}. Error: {e}")

    return all_tickers


def main():
    """Main function to run the Reddit stock analysis tool."""
    parser = argparse.ArgumentParser(description="Analyze stock ticker mentions on Reddit.")
    parser.add_argument("config_file", help="Path to the JSON file containing subreddits.")
    args = parser.parse_args()

    # --- Part 1: Load Configuration & Initialize Reddit ---
    subreddits = load_subreddits(args.config_file)
    if not subreddits:
        return

    reddit = get_reddit_instance()
    if not reddit:
        return

    # --- Part 2: Scan Reddit for Tickers ---
    ticker_counts = scan_subreddits(reddit, subreddits)
    if not ticker_counts:
        print("No tickers found.")
        return

    print("\n--- Scan Complete ---")
    print("Top 15 mentioned tickers:")

    # --- Part 3: Display Results ---
    # We will enrich this data with market cap and sentiment in the next steps
    for ticker, count in ticker_counts.most_common(15):
        print(f"{ticker}: {count} mentions")


if __name__ == "__main__":
    main()