import re
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import ContentTooLargeError, DocumentNotFoundError
from app.models.document import Document
from app.schemas.document import CreateDocumentRequest, UpdateDocumentRequest

CONTENT_MAX_BYTES = 200_000  # 200KB, measured on the UTF-8 encoding

# ATX heading line: 1-6 '#' characters followed by whitespace.
_HEADING_RE = re.compile(r"^(#{1,6})\s+")
# Opening/closing line of a fenced code block (``` or ~~~, up to 3 spaces indent).
_FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})")


def _find_headings(lines: list[str]) -> list[tuple[int, int]]:
    """Return ``(line_index, level)`` for every heading line outside fenced code.

    Lines inside ``` / ~~~ fences are skipped, so that e.g. a ``# comment``
    inside a shell or Python snippet is not mistaken for a heading. An
    unclosed fence swallows the rest of the document.
    """
    headings: list[tuple[int, int]] = []
    open_fence: str | None = None  # marker run of the currently open fence
    for i, line in enumerate(lines):
        fence = _FENCE_RE.match(line)
        if open_fence is not None:
            # A fence closes on the same marker character, at least as long
            # as the opener, with only whitespace after it.
            if (
                fence
                and fence.group(1)[0] == open_fence[0]
                and len(fence.group(1)) >= len(open_fence)
                and not line[fence.end():].strip()
            ):
                open_fence = None
            continue
        if fence:
            open_fence = fence.group(1)
            continue
        m = _HEADING_RE.match(line)
        if m:
            headings.append((i, len(m.group(1))))
    return headings


class DocumentService:
    """CRUD operations on ``Document`` rows through an ``AsyncSession``."""

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #
    @staticmethod
    def _ensure_content_size(content: str) -> None:
        """Raise ``ContentTooLargeError(size)`` if *content* exceeds the limit.

        The size is the length of the UTF-8 encoding in bytes.
        """
        size = len(content.encode("utf-8"))
        if size > CONTENT_MAX_BYTES:
            raise ContentTooLargeError(size)

    # ------------------------------------------------------------------ #
    # CREATE
    # ------------------------------------------------------------------ #
    async def create_document(
        self,
        data: CreateDocumentRequest,
        user_id: str | None = None,
    ) -> Document:
        """Insert a new document and return it refreshed from the database.

        Args:
            data: Request payload with title, content and metadata fields.
            user_id: Stored as ``created_by``; may be ``None``.

        Raises:
            ContentTooLargeError: if ``data.content`` exceeds
                ``CONTENT_MAX_BYTES`` when UTF-8 encoded.
        """
        self._ensure_content_size(data.content)

        doc = Document(
            title=data.title,
            content=data.content,
            format=data.format,
            session_id=data.session_id,
            template_id=data.template_id,
            source=data.source,
            created_by=user_id,
        )
        self.db.add(doc)
        await self.db.commit()
        await self.db.refresh(doc)
        return doc

    # ------------------------------------------------------------------ #
    # READ
    # ------------------------------------------------------------------ #
    async def get_document(self, document_id: str) -> Document:
        """Return the document with id *document_id*.

        Raises:
            DocumentNotFoundError: if no such document exists.
        """
        result = await self.db.execute(
            select(Document).where(Document.id == document_id)
        )
        doc = result.scalar_one_or_none()
        if doc is None:
            raise DocumentNotFoundError(document_id)
        return doc

    async def list_documents(
        self,
        page: int = 1,
        page_size: int = 20,
        session_id: str | None = None,
        source: str | None = None,
        sort_by: str = "updated_at",
        sort_order: str = "desc",
    ) -> tuple[list[Document], int]:
        """Return one page of documents plus the total number of matches.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Maximum number of documents on the page.
            session_id: If truthy, only documents with this session_id.
            source: If truthy, only documents with this source.
            sort_by: Name of a mapped ``Document`` column. Anything else
                falls back to ``updated_at``.
            sort_order: ``"asc"`` for ascending; any other value sorts
                descending.

        Returns:
            ``(documents_on_page, total_matching_count)``.
        """
        query = select(Document)
        if session_id:
            query = query.where(Document.session_id == session_id)
        if source:
            query = query.where(Document.source == source)

        # Total: count in the database instead of fetching every matching id,
        # and before ORDER BY is attached since ordering is irrelevant here.
        count_query = select(func.count()).select_from(query.subquery())
        total = (await self.db.execute(count_query)).scalar_one()

        # Sorting: sort_by may come from the client, so only mapped columns
        # are accepted. Other attributes (relationships, methods, dunders)
        # would fail on .asc()/.desc(), so they fall back to updated_at.
        if sort_by in sa_inspect(Document).columns:
            sort_col = getattr(Document, sort_by)
        else:
            sort_col = Document.updated_at
        if sort_order == "asc":
            query = query.order_by(sort_col.asc())
        else:
            query = query.order_by(sort_col.desc())

        # Pagination: clamp so page <= 0 never yields a negative OFFSET.
        offset = max(page - 1, 0) * page_size
        result = await self.db.execute(query.offset(offset).limit(page_size))
        docs = list(result.scalars().all())
        return docs, total

    # ------------------------------------------------------------------ #
    # UPDATE
    # ------------------------------------------------------------------ #
    async def update_document(
        self, document_id: str, data: UpdateDocumentRequest
    ) -> Document:
        """Apply a partial update to a document and return it refreshed.

        ``data.content`` (if given) replaces the whole body. ``data.blocks``
        (if given) is then applied to the resulting body via
        ``_apply_block_updates``. The final body is size-checked before
        anything on the ORM object is modified.

        Raises:
            DocumentNotFoundError: if the document does not exist.
            ContentTooLargeError: if the new content, or the content after
                block updates, exceeds ``CONTENT_MAX_BYTES``.
        """
        doc = await self.get_document(document_id)

        # Build and validate the new body first, so a rejected update
        # leaves no partially-modified object in the session.
        new_content = doc.content
        if data.content is not None:
            self._ensure_content_size(data.content)
            new_content = data.content
        if data.blocks is not None:
            new_content = self._apply_block_updates(new_content, data.blocks)
            # Block patches can also grow the document past the limit.
            self._ensure_content_size(new_content)

        if data.title is not None:
            doc.title = data.title
        if data.content is not None or data.blocks is not None:
            doc.content = new_content

        doc.updated_at = datetime.now(timezone.utc)
        await self.db.commit()
        await self.db.refresh(doc)
        return doc

    # ------------------------------------------------------------------ #
    # DELETE
    # ------------------------------------------------------------------ #
    async def delete_document(self, document_id: str) -> None:
        """Delete the document with id *document_id* and commit.

        Raises:
            DocumentNotFoundError: if the document does not exist.
        """
        doc = await self.get_document(document_id)
        await self.db.delete(doc)
        await self.db.commit()

    # ------------------------------------------------------------------ #
    # Partial block update: locate heading blocks by level + index and
    # replace them
    # ------------------------------------------------------------------ #
    @staticmethod
    def _apply_block_updates(content: str, blocks: list) -> str:
        """Split *content* into heading blocks, replace some, and re-join.

        A block starts at a heading line (``#`` .. ``######`` followed by
        whitespace, outside fenced code) and runs up to, but not including,
        the next heading of any level. Text before the first heading is a
        separate preamble that cannot be targeted.

        Each heading is keyed by ``(level, index)``, where ``index`` counts
        earlier headings of the same level (0-based). Each item in *blocks*
        must expose ``level``, ``index`` and ``content``. The matching
        block, heading line included, is replaced by ``content``. Items with
        no matching heading are ignored, and if two items share a key, the
        later one wins.

        The result is stripped of leading/trailing whitespace and ends with
        exactly one newline.
        """
        lines = content.split("\n")

        # Key every heading by (level, per-level occurrence index).
        level_counter: dict[int, int] = {}
        headings: list[tuple[int, tuple[int, int]]] = []  # (line_idx, key)
        for line_idx, level in _find_headings(lines):
            index = level_counter.get(level, 0)
            level_counter[level] = index + 1
            headings.append((line_idx, (level, index)))

        # Chunk 0 is the preamble; chunk k (k >= 1) is the k-th heading block.
        boundaries = [line_idx for line_idx, _ in headings] + [len(lines)]
        chunks: list[str] = ["\n".join(lines[: boundaries[0]])]
        chunk_for_key: dict[tuple[int, int], int] = {}
        for pos, (start, key) in enumerate(headings):
            chunks.append("\n".join(lines[start : boundaries[pos + 1]]))
            chunk_for_key[key] = pos + 1

        for block_update in blocks:
            target = chunk_for_key.get((block_update.level, block_update.index))
            if target is not None:
                chunks[target] = block_update.content

        return "\n".join(chunks).strip() + "\n"