""" TemplateSearchService centralizes query construction for templates search and keyword management, keeping API endpoints thin and consistent. Adds best-effort caching using Redis when available with an in-memory fallback. Cache keys are built from normalized query params. """ from __future__ import annotations from typing import List, Optional, Tuple, Dict, Any import json import time import threading from sqlalchemy import func, or_, exists from sqlalchemy.orm import Session from app.models.templates import DocumentTemplate, DocumentTemplateVersion, TemplateKeyword from app.services.cache import cache_get_json, cache_set_json, invalidate_prefix class TemplateSearchService: _mem_cache: Dict[str, Tuple[float, Any]] = {} _mem_lock = threading.RLock() _SEARCH_TTL_SECONDS = 60 # Fallback TTL _CATEGORIES_TTL_SECONDS = 120 # Fallback TTL def __init__(self, db: Session) -> None: self.db = db async def search_templates( self, *, q: Optional[str], categories: Optional[List[str]], keywords: Optional[List[str]], keywords_mode: str, has_keywords: Optional[bool], skip: int, limit: int, sort_by: str, sort_dir: str, active_only: bool, include_total: bool, ) -> Tuple[List[Dict[str, Any]], Optional[int]]: # Build normalized cache key parts norm_categories = sorted({c for c in (categories or []) if c}) or None norm_keywords = sorted({(kw or "").strip().lower() for kw in (keywords or []) if kw and kw.strip()}) or None norm_mode = (keywords_mode or "any").lower() if norm_mode not in ("any", "all"): norm_mode = "any" norm_sort_by = (sort_by or "name").lower() if norm_sort_by not in ("name", "category", "updated"): norm_sort_by = "name" norm_sort_dir = (sort_dir or "asc").lower() if norm_sort_dir not in ("asc", "desc"): norm_sort_dir = "asc" parts = { "q": q or "", "categories": norm_categories, "keywords": norm_keywords, "keywords_mode": norm_mode, "has_keywords": has_keywords, "skip": int(skip), "limit": int(limit), "sort_by": norm_sort_by, "sort_dir": norm_sort_dir, "active_only": bool(active_only), "include_total": bool(include_total), } # Try cache first (local then adaptive) cached = self._cache_get_local("templates", parts) if cached is None: try: from app.services.adaptive_cache import adaptive_cache_get cached = await adaptive_cache_get( cache_type="templates", cache_key="template_search", parts=parts ) except Exception: cached = await self._cache_get_redis("templates", parts) if cached is not None: return cached["items"], cached.get("total") query = self.db.query(DocumentTemplate) if active_only: query = query.filter(DocumentTemplate.active == True) # noqa: E712 if q: like = f"%{q}%" query = query.filter( or_( DocumentTemplate.name.ilike(like), DocumentTemplate.description.ilike(like), ) ) if norm_categories: query = query.filter(DocumentTemplate.category.in_(norm_categories)) if norm_keywords: query = query.join(TemplateKeyword, TemplateKeyword.template_id == DocumentTemplate.id) if norm_mode == "any": query = query.filter(TemplateKeyword.keyword.in_(norm_keywords)).distinct() else: query = query.filter(TemplateKeyword.keyword.in_(norm_keywords)) query = query.group_by(DocumentTemplate.id) query = query.having(func.count(func.distinct(TemplateKeyword.keyword)) == len(norm_keywords)) if has_keywords is not None: kw_exists = exists().where(TemplateKeyword.template_id == DocumentTemplate.id) if has_keywords: query = query.filter(kw_exists) else: query = query.filter(~kw_exists) if norm_sort_by == "name": order_col = DocumentTemplate.name elif norm_sort_by == "category": order_col = DocumentTemplate.category else: order_col = func.coalesce(DocumentTemplate.updated_at, DocumentTemplate.created_at) if norm_sort_dir == "asc": query = query.order_by(order_col.asc()) else: query = query.order_by(order_col.desc()) total = query.count() if include_total else None templates: List[DocumentTemplate] = query.offset(skip).limit(limit).all() # Resolve latest version semver for current_version_id in bulk current_ids = [t.current_version_id for t in templates if t.current_version_id] latest_by_version_id: dict[int, str] = {} if current_ids: rows = ( self.db.query(DocumentTemplateVersion.id, DocumentTemplateVersion.semantic_version) .filter(DocumentTemplateVersion.id.in_(current_ids)) .all() ) latest_by_version_id = {row[0]: row[1] for row in rows} items: List[Dict[str, Any]] = [] for tpl in templates: latest_version = latest_by_version_id.get(int(tpl.current_version_id)) if tpl.current_version_id else None items.append({ "id": tpl.id, "name": tpl.name, "category": tpl.category, "active": tpl.active, "latest_version": latest_version, }) payload = {"items": items, "total": total} # Store in caches (best-effort) self._cache_set_local("templates", parts, payload, self._SEARCH_TTL_SECONDS) try: from app.services.adaptive_cache import adaptive_cache_set await adaptive_cache_set( cache_type="templates", cache_key="template_search", value=payload, parts=parts ) except Exception: await self._cache_set_redis("templates", parts, payload, self._SEARCH_TTL_SECONDS) return items, total async def list_categories(self, *, active_only: bool) -> List[tuple[Optional[str], int]]: parts = {"active_only": bool(active_only)} cached = self._cache_get_local("templates_categories", parts) if cached is None: cached = await self._cache_get_redis("templates_categories", parts) if cached is not None: items = cached.get("items") or [] return [(row[0], row[1]) for row in items] query = self.db.query(DocumentTemplate.category, func.count(DocumentTemplate.id).label("count")) if active_only: query = query.filter(DocumentTemplate.active == True) # noqa: E712 rows = query.group_by(DocumentTemplate.category).order_by(DocumentTemplate.category.asc()).all() items = [(row[0], row[1]) for row in rows] payload = {"items": items} self._cache_set_local("templates_categories", parts, payload, self._CATEGORIES_TTL_SECONDS) await self._cache_set_redis("templates_categories", parts, payload, self._CATEGORIES_TTL_SECONDS) return items def list_keywords(self, template_id: int) -> List[str]: _ = self._get_template_or_404(template_id) rows = ( self.db.query(TemplateKeyword) .filter(TemplateKeyword.template_id == template_id) .order_by(TemplateKeyword.keyword.asc()) .all() ) return [r.keyword for r in rows] async def add_keywords(self, template_id: int, keywords: List[str]) -> List[str]: _ = self._get_template_or_404(template_id) to_add = [] for kw in (keywords or []): normalized = (kw or "").strip().lower() if not normalized: continue exists_row = ( self.db.query(TemplateKeyword) .filter(TemplateKeyword.template_id == template_id, TemplateKeyword.keyword == normalized) .first() ) if not exists_row: to_add.append(TemplateKeyword(template_id=template_id, keyword=normalized)) if to_add: self.db.add_all(to_add) self.db.commit() # Invalidate caches affected by keyword changes await self.invalidate_all() return self.list_keywords(template_id) async def remove_keyword(self, template_id: int, keyword: str) -> List[str]: _ = self._get_template_or_404(template_id) normalized = (keyword or "").strip().lower() if normalized: self.db.query(TemplateKeyword).filter( TemplateKeyword.template_id == template_id, TemplateKeyword.keyword == normalized, ).delete(synchronize_session=False) self.db.commit() await self.invalidate_all() return self.list_keywords(template_id) def _get_template_or_404(self, template_id: int) -> DocumentTemplate: # Local import to avoid circular from app.services.template_service import get_template_or_404 as _get return _get(self.db, template_id) # ---- Cache helpers ---- @classmethod def _build_mem_key(cls, kind: str, parts: dict) -> str: # Deterministic key return f"search:{kind}:v1:{json.dumps(parts, sort_keys=True, separators=(",", ":"))}" @classmethod def _cache_get_local(cls, kind: str, parts: dict) -> Optional[dict]: key = cls._build_mem_key(kind, parts) now = time.time() with cls._mem_lock: entry = cls._mem_cache.get(key) if not entry: return None expires_at, value = entry if expires_at <= now: try: del cls._mem_cache[key] except Exception: pass return None return value @classmethod def _cache_set_local(cls, kind: str, parts: dict, value: dict, ttl_seconds: int) -> None: key = cls._build_mem_key(kind, parts) expires_at = time.time() + max(1, int(ttl_seconds)) with cls._mem_lock: cls._mem_cache[key] = (expires_at, value) @staticmethod async def _cache_get_redis(kind: str, parts: dict) -> Optional[dict]: try: return await cache_get_json(kind, None, parts) except Exception: return None @staticmethod async def _cache_set_redis(kind: str, parts: dict, value: dict, ttl_seconds: int) -> None: try: await cache_set_json(kind, None, parts, value, ttl_seconds) except Exception: return @classmethod async def invalidate_all(cls) -> None: # Clear in-memory with cls._mem_lock: cls._mem_cache.clear() # Best-effort Redis invalidation try: await invalidate_prefix("search:templates:") await invalidate_prefix("search:templates_categories:") except Exception: pass # Helper to run async cache calls from sync context def asyncio_run(aw): # type: ignore # Not used anymore; kept for backward compatibility if imported elsewhere try: import asyncio return asyncio.run(aw) except Exception: return None