| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
import re
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import ContentTooLargeError, DocumentNotFoundError
from app.models.document import Document
from app.schemas.document import CreateDocumentRequest, UpdateDocumentRequest
# Upper bound on a document's content, measured in bytes of its UTF-8
# encoding (200 KB, decimal kilobytes).
CONTENT_MAX_BYTES: int = 200 * 1000
class DocumentService:
    """Service layer for creating, reading, updating and deleting documents.

    Every mutating method commits the ``AsyncSession`` passed to the
    constructor before returning.
    """

    # ATX-style heading line: 1-6 '#' characters followed by whitespace.
    _HEADING_RE = re.compile(r"^(#{1,6})\s+")
    # Opening/closing line of a fenced code block (``` or ~~~, at most three
    # leading spaces). Group 1 is the fence marker, group 2 whatever follows it.
    _FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})(.*)$")

    def __init__(self, db: AsyncSession) -> None:
        self.db = db

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #
    @staticmethod
    def _check_content_size(content: str) -> None:
        """Raise ``ContentTooLargeError`` if *content* exceeds the size limit.

        The size is measured on the UTF-8 encoding, so multi-byte characters
        count for more than one byte. The encoded length is passed to the
        exception.
        """
        size = len(content.encode("utf-8"))
        if size > CONTENT_MAX_BYTES:
            raise ContentTooLargeError(size)

    # ------------------------------------------------------------------ #
    # CREATE
    # ------------------------------------------------------------------ #
    async def create_document(
        self,
        data: CreateDocumentRequest,
        user_id: str | None = None,
    ) -> Document:
        """Insert a new document built from *data*, commit, and return it.

        Args:
            data: Creation payload; its fields are copied onto the new row.
            user_id: Stored as ``created_by``.

        Raises:
            ContentTooLargeError: if ``data.content`` exceeds
                ``CONTENT_MAX_BYTES`` once UTF-8 encoded.
        """
        self._check_content_size(data.content)
        doc = Document(
            title=data.title,
            content=data.content,
            format=data.format,
            session_id=data.session_id,
            template_id=data.template_id,
            source=data.source,
            created_by=user_id,
        )
        self.db.add(doc)
        await self.db.commit()
        # Reload so server-generated columns are populated on the instance.
        await self.db.refresh(doc)
        return doc

    # ------------------------------------------------------------------ #
    # READ
    # ------------------------------------------------------------------ #
    async def get_document(self, document_id: str) -> Document:
        """Return the document whose ``id`` equals *document_id*.

        Raises:
            DocumentNotFoundError: if no such document exists.
        """
        result = await self.db.execute(
            select(Document).where(Document.id == document_id)
        )
        doc = result.scalar_one_or_none()
        if doc is None:
            raise DocumentNotFoundError(document_id)
        return doc

    async def list_documents(
        self,
        page: int = 1,
        page_size: int = 20,
        session_id: str | None = None,
        source: str | None = None,
        sort_by: str = "updated_at",
        sort_order: str = "desc",
    ) -> tuple[list[Document], int]:
        """Return one page of documents plus the total number of matches.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Maximum number of documents on the page.
            session_id: When truthy, only documents with this ``session_id``.
            source: When truthy, only documents with this ``source``.
            sort_by: Name of a mapped column attribute of ``Document``. Any
                other value (unknown names, relationships, methods) falls back
                to ``updated_at``.
            sort_order: ``"asc"`` sorts ascending; any other value descending.

        Returns:
            ``(documents, total)``. ``total`` counts every row matching the
            filters, independent of pagination.
        """
        query = select(Document)
        if session_id:
            query = query.where(Document.session_id == session_id)
        if source:
            query = query.where(Document.source == source)

        # Total: count in the database rather than fetching every matching id.
        count_query = select(func.count()).select_from(query.subquery())
        total = (await self.db.execute(count_query)).scalar_one()

        # Sorting: only accept real column attributes. A bare getattr() on
        # caller input could reach methods/relationships and fail at .asc().
        if sort_by in sa_inspect(Document).column_attrs:
            sort_col = getattr(Document, sort_by)
        else:
            sort_col = Document.updated_at
        ascending = sort_order == "asc"
        # Secondary order on id keeps OFFSET pagination stable when several
        # rows share the same sort value.
        if ascending:
            query = query.order_by(sort_col.asc(), Document.id.asc())
        else:
            query = query.order_by(sort_col.desc(), Document.id.desc())

        # Paging: clamp so a page below 1 never yields a negative OFFSET.
        offset = (max(page, 1) - 1) * page_size
        result = await self.db.execute(query.offset(offset).limit(page_size))
        docs = list(result.scalars().all())
        return docs, total

    # ------------------------------------------------------------------ #
    # UPDATE
    # ------------------------------------------------------------------ #
    async def update_document(
        self, document_id: str, data: UpdateDocumentRequest
    ) -> Document:
        """Apply the non-``None`` fields of *data* to an existing document.

        ``content`` (full replacement) is applied first, then ``blocks``
        (heading-block replacements) are applied on top of the result. The
        size limit is enforced on the supplied ``content`` and again on the
        content produced by block updates, so block edits cannot grow a
        document past ``CONTENT_MAX_BYTES``. All validation happens before
        any attribute is modified, so a rejected update leaves the instance
        unchanged. ``updated_at`` is always refreshed on success.

        Raises:
            DocumentNotFoundError: if no document has ``document_id``.
            ContentTooLargeError: if the new content exceeds the size limit.
        """
        doc = await self.get_document(document_id)

        new_content = doc.content
        if data.content is not None:
            self._check_content_size(data.content)
            new_content = data.content
        if data.blocks is not None:
            new_content = self._apply_block_updates(new_content, data.blocks)
            self._check_content_size(new_content)

        # Mutate only after every check has passed.
        if data.title is not None:
            doc.title = data.title
        if data.content is not None or data.blocks is not None:
            doc.content = new_content
        doc.updated_at = datetime.now(timezone.utc)
        await self.db.commit()
        await self.db.refresh(doc)
        return doc

    # ------------------------------------------------------------------ #
    # DELETE
    # ------------------------------------------------------------------ #
    async def delete_document(self, document_id: str) -> None:
        """Delete the document with *document_id* and commit.

        Raises:
            DocumentNotFoundError: if no such document exists.
        """
        doc = await self.get_document(document_id)
        await self.db.delete(doc)
        await self.db.commit()

    # ------------------------------------------------------------------ #
    # Partial block update: locate heading blocks by (level, index)
    # ------------------------------------------------------------------ #
    @classmethod
    def _heading_lines(cls, lines: list[str]) -> list[tuple[int, int]]:
        """Return ``(line_index, level)`` for each heading outside code fences.

        A fence opens on a line of three or more backticks or tildes and
        closes on a line of the same character that is at least as long and
        has nothing else on it (CommonMark rules). An unclosed fence runs to
        the end of the text. A backtick line whose trailing text contains a
        backtick is inline code, not a fence opener.
        """
        headings: list[tuple[int, int]] = []
        open_fence: str | None = None
        for i, line in enumerate(lines):
            fence_match = cls._FENCE_RE.match(line)
            if open_fence is None:
                if fence_match:
                    marker, rest = fence_match.group(1), fence_match.group(2)
                    if not (marker[0] == "`" and "`" in rest):
                        open_fence = marker
                        continue
                heading_match = cls._HEADING_RE.match(line)
                if heading_match:
                    headings.append((i, len(heading_match.group(1))))
            elif (
                fence_match
                and fence_match.group(1)[0] == open_fence[0]
                and len(fence_match.group(1)) >= len(open_fence)
                and not fence_match.group(2).strip()
            ):
                open_fence = None
        return headings

    @classmethod
    def _apply_block_updates(cls, content: str, blocks: list) -> str:
        """Replace heading blocks of *content* addressed by ``(level, index)``.

        The text is split at heading lines. Each block is a heading line plus
        every following line up to the next heading of any level. Text before
        the first heading is kept as a preamble that cannot be addressed.
        ``index`` is the 0-based position of a heading among headings of the
        same ``level``, in document order. ``#`` lines inside fenced code
        blocks are not headings.

        Each item of *blocks* must expose ``level``, ``index`` and ``content``.
        Its ``content`` replaces the whole addressed block, heading line
        included. All addresses refer to the original text. Items whose
        address matches no heading are silently ignored, and a later item
        overrides an earlier one with the same address.

        The result is stripped of leading/trailing whitespace and ends with
        exactly one newline.
        """
        lines = content.split("\n")
        headings = cls._heading_lines(lines)

        # Block boundaries are the heading lines, plus an end-of-text sentinel.
        boundaries = [line_idx for line_idx, _ in headings] + [len(lines)]
        chunks: list[str] = ["\n".join(lines[: boundaries[0]])]  # preamble
        chunk_for_key: dict[tuple[int, int], int] = {}
        level_counter: dict[int, int] = {}
        for n, (line_idx, level) in enumerate(headings):
            index = level_counter.get(level, 0)
            level_counter[level] = index + 1
            chunks.append("\n".join(lines[line_idx : boundaries[n + 1]]))
            chunk_for_key[(level, index)] = len(chunks) - 1

        for block_update in blocks:
            chunk_idx = chunk_for_key.get((block_update.level, block_update.index))
            if chunk_idx is not None:
                chunks[chunk_idx] = block_update.content

        return "\n".join(chunks).strip() + "\n"
|