# document_service.py - document CRUD service with per-heading block updates.

import re
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import ContentTooLargeError, DocumentNotFoundError
from app.models.document import Document
from app.schemas.document import CreateDocumentRequest, UpdateDocumentRequest
  8. CONTENT_MAX_BYTES = 200_000 # 200KB
  9. class DocumentService:
  10. def __init__(self, db: AsyncSession) -> None:
  11. self.db = db
  12. # ------------------------------------------------------------------ #
  13. # CREATE
  14. # ------------------------------------------------------------------ #
  15. async def create_document(
  16. self,
  17. data: CreateDocumentRequest,
  18. user_id: str | None = None,
  19. ) -> Document:
  20. if len(data.content.encode("utf-8")) > CONTENT_MAX_BYTES:
  21. raise ContentTooLargeError(len(data.content.encode("utf-8")))
  22. doc = Document(
  23. title=data.title,
  24. content=data.content,
  25. format=data.format,
  26. session_id=data.session_id,
  27. template_id=data.template_id,
  28. created_by=user_id,
  29. )
  30. self.db.add(doc)
  31. await self.db.commit()
  32. await self.db.refresh(doc)
  33. return doc
  34. # ------------------------------------------------------------------ #
  35. # READ
  36. # ------------------------------------------------------------------ #
  37. async def get_document(self, document_id: str) -> Document:
  38. result = await self.db.execute(
  39. select(Document).where(Document.id == document_id)
  40. )
  41. doc = result.scalar_one_or_none()
  42. if doc is None:
  43. raise DocumentNotFoundError(document_id)
  44. return doc
  45. async def list_documents(
  46. self,
  47. page: int = 1,
  48. page_size: int = 20,
  49. session_id: str | None = None,
  50. sort_by: str = "updated_at",
  51. sort_order: str = "desc",
  52. ) -> tuple[list[Document], int]:
  53. query = select(Document)
  54. if session_id:
  55. query = query.where(Document.session_id == session_id)
  56. # 排序
  57. sort_col = getattr(Document, sort_by, Document.updated_at)
  58. if sort_order == "asc":
  59. query = query.order_by(sort_col.asc())
  60. else:
  61. query = query.order_by(sort_col.desc())
  62. # 总数
  63. count_result = await self.db.execute(
  64. query.with_only_columns(Document.id)
  65. )
  66. total = len(count_result.all())
  67. # 分页
  68. offset = (page - 1) * page_size
  69. result = await self.db.execute(query.offset(offset).limit(page_size))
  70. docs = list(result.scalars().all())
  71. return docs, total
  72. # ------------------------------------------------------------------ #
  73. # UPDATE
  74. # ------------------------------------------------------------------ #
  75. async def update_document(
  76. self, document_id: str, data: UpdateDocumentRequest
  77. ) -> Document:
  78. doc = await self.get_document(document_id)
  79. if data.title is not None:
  80. doc.title = data.title
  81. if data.content is not None:
  82. if len(data.content.encode("utf-8")) > CONTENT_MAX_BYTES:
  83. raise ContentTooLargeError(len(data.content.encode("utf-8")))
  84. doc.content = data.content
  85. if data.blocks is not None:
  86. doc.content = self._apply_block_updates(doc.content, data.blocks)
  87. doc.updated_at = datetime.now(timezone.utc)
  88. await self.db.commit()
  89. await self.db.refresh(doc)
  90. return doc
  91. # ------------------------------------------------------------------ #
  92. # DELETE
  93. # ------------------------------------------------------------------ #
  94. async def delete_document(self, document_id: str) -> None:
  95. doc = await self.get_document(document_id)
  96. await self.db.delete(doc)
  97. await self.db.commit()
  98. # ------------------------------------------------------------------ #
  99. # 局部块更新:按 level + index 定位并替换对应的标题块
  100. # ------------------------------------------------------------------ #
  101. @staticmethod
  102. def _apply_block_updates(content: str, blocks: list) -> str:
  103. """
  104. 将文档按标题行拆分成若干块,按 level+index 替换对应块后重新拼接。
  105. 块的定义:以标题行(# / ## / ...)为分割点,
  106. 每个标题及其下属正文构成一个块。
  107. """
  108. lines = content.split("\n")
  109. # 找出所有标题行的位置
  110. heading_pattern = re.compile(r"^(#{1,6})\s+")
  111. heading_positions: list[tuple[int, int]] = [] # (line_index, level)
  112. for i, line in enumerate(lines):
  113. m = heading_pattern.match(line)
  114. if m:
  115. level = len(m.group(1))
  116. heading_positions.append((i, level))
  117. # 统计每个 level 已出现的次数,得到 index
  118. level_counter: dict[int, int] = {}
  119. heading_info: list[tuple[int, int, int]] = [] # (line_idx, level, index)
  120. for line_idx, level in heading_positions:
  121. idx = level_counter.get(level, 0)
  122. heading_info.append((line_idx, level, idx))
  123. level_counter[level] = idx + 1
  124. # 构建查找字典 (level, index) → line_idx
  125. heading_map: dict[tuple[int, int], int] = {
  126. (level, index): line_idx
  127. for line_idx, level, index in heading_info
  128. }
  129. # 将 lines 拆成块
  130. # 块边界 = 各标题行的 line_idx
  131. split_points = sorted({line_idx for line_idx, _, _ in heading_info})
  132. split_points.append(len(lines)) # 末尾哨兵
  133. # 前置正文(首个标题之前的内容)
  134. pre_content_end = split_points[0] if split_points else len(lines)
  135. chunks: list[str] = []
  136. chunks.append("\n".join(lines[:pre_content_end]))
  137. # 各标题块
  138. block_keys: list[tuple[int, int] | None] = [None] # 对应 pre_content
  139. for i, sp in enumerate(split_points[:-1]):
  140. end = split_points[i + 1]
  141. chunks.append("\n".join(lines[sp:end]))
  142. _, level, index = heading_info[i]
  143. block_keys.append((level, index))
  144. # 执行替换
  145. for block_update in blocks:
  146. key = (block_update.level, block_update.index)
  147. if key in heading_map:
  148. chunk_idx = block_keys.index(key)
  149. chunks[chunk_idx] = block_update.content
  150. return "\n".join(chunks).strip() + "\n"