diff --git a/scraper/__init__.py b/scraper/__init__.py
index 5df7bc2..2951812 100644
--- a/scraper/__init__.py
+++ b/scraper/__init__.py
@@ -1,5 +1,6 @@
 """Reddit scraping module using Selenium for page-based scraping."""
 from .selenium_scrapers import get_scraper, RedditScraper
+from .cache import RequestCache, get_cache
 
-__all__ = ["get_scraper", "RedditScraper"]
+__all__ = ["get_scraper", "RedditScraper", "RequestCache", "get_cache"]
 
diff --git a/scraper/cache.py b/scraper/cache.py
new file mode 100644
index 0000000..ab31135
--- /dev/null
+++ b/scraper/cache.py
@@ -0,0 +1,93 @@
+"""Request caching module for Reddit Scraper."""
+
+import hashlib
+from typing import Any, Dict, Optional, Union
+from datetime import datetime, timedelta
+
+
+class RequestCache:
+    """In-memory TTL cache for scraping results.
+
+    NOTE: entries are evicted lazily (only when read after expiry), so the
+    cache can grow unboundedly under many distinct query combinations.
+    """
+
+    def __init__(self, default_ttl_seconds: int = 300):
+        """
+        Initialize the cache.
+
+        Args:
+            default_ttl_seconds: Time-to-live for cached entries (default: 5 minutes)
+        """
+        self.default_ttl = timedelta(seconds=default_ttl_seconds)
+        self._cache: Dict[str, Dict[str, Any]] = {}
+
+    def _generate_key(self, **kwargs) -> str:
+        """Generate a cache key from query parameters."""
+        # Sort kwargs to ensure consistent ordering
+        sorted_kwargs = sorted(kwargs.items())
+        key_str = str(sorted_kwargs)
+        # md5 is fine here: keys are non-adversarial, not security-sensitive
+        return hashlib.md5(key_str.encode()).hexdigest()
+
+    def get(self, **kwargs) -> Optional[Dict[str, Any]]:
+        """
+        Get cached result if still valid.
+
+        Args:
+            kwargs: Query parameters to match
+
+        Returns:
+            Cached result or None if not found/expired
+        """
+        key = self._generate_key(**kwargs)
+
+        if key not in self._cache:
+            return None
+
+        entry = self._cache[key]
+
+        # Check expiration (lazy eviction on read)
+        if datetime.now() > entry['expires_at']:
+            del self._cache[key]
+            return None
+
+        return entry['result']
+
+    def set(self, result: Dict[str, Any], ttl_seconds: Optional[Union[int, timedelta]] = None, **kwargs):
+        """
+        Cache a scraping result.
+
+        Args:
+            result: The scraping result to cache
+            ttl_seconds: Override default TTL, as seconds or a timedelta (optional)
+            kwargs: Query parameters as key=value pairs
+        """
+        key = self._generate_key(**kwargs)
+
+        # Calculate expiration time (handle both int seconds and timedelta objects)
+        if isinstance(ttl_seconds, int):
+            ttl_delta = timedelta(seconds=ttl_seconds)
+        else:
+            ttl_delta = ttl_seconds or self.default_ttl
+
+        expires_at = datetime.now() + ttl_delta
+
+        self._cache[key] = {
+            'result': result,
+            'expires_at': expires_at,
+            'cached_at': datetime.now()
+        }
+
+    def clear(self):
+        """Clear all cached entries."""
+        self._cache.clear()
+
+
+# Global cache instance (shared across scraper instances)
+_cache_instance: Optional[RequestCache] = None
+
+
+def get_cache(ttl_seconds: int = 300) -> RequestCache:
+    """Get or create the global cache; ttl_seconds only applies on first call."""
+    global _cache_instance
+
+    if _cache_instance is None:
+        _cache_instance = RequestCache(default_ttl_seconds=ttl_seconds)
+
+    return _cache_instance
diff --git a/scraper/selenium_scrapers.py b/scraper/selenium_scrapers.py
index ec0f06f..f1d5108 100644
--- a/scraper/selenium_scrapers.py
+++ b/scraper/selenium_scrapers.py
@@ -6,6 +6,8 @@
 from selenium import webdriver
 from selenium.webdriver.firefox.options import Options
 from selenium.webdriver.firefox.service import Service
+from .cache import get_cache
+
 
 class RedditScraper:
     """Scrapes OLD Reddit pages (old.reddit.com) using a headless Firefox browser."""
@@ -232,6 +234,17 @@
         Returns:
             Dict containing scraped data or error information
         """
+        # Check cache first (skip if comments requested, as they change frequently)
+        if not include_comments:
+            cached_result = get_cache().get(
+                subreddit=subreddit,
+                limit=limit,
+                time_range=time_range,
+                depth=depth
+            )
+            if cached_result is not None:
+                return cached_result
+
         self._ensure_browser()
 
         try:
@@ -260,13 +273,20 @@
                 }
                 posts.append(post_obj)
 
-            return {
+            # Build final result structure
+            result = {
                 "subreddit": subreddit,
                 "time_range": time_range,
                 "limit": len(posts),
                 "posts_count": len(posts),
                 "data": posts
             }
+
+            # Cache the result (only if no comments requested)
+            if not include_comments:
+                get_cache().set(result, subreddit=subreddit, limit=limit, time_range=time_range, depth=depth)
+
+            return result
 
         except Exception as e:
             import traceback