# document_service.py
import re
from datetime import datetime, timezone

from sqlalchemy import func, inspect, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import ContentTooLargeError, DocumentNotFoundError
from app.models.document import Document
from app.schemas.document import CreateDocumentRequest, UpdateDocumentRequest
  8. CONTENT_MAX_BYTES = 200_000 # 200KB
  9. class DocumentService:
  10. def __init__(self, db: AsyncSession) -> None:
  11. self.db = db
  12. # ------------------------------------------------------------------ #
  13. # CREATE
  14. # ------------------------------------------------------------------ #
  15. async def create_document(
  16. self,
  17. data: CreateDocumentRequest,
  18. user_id: str | None = None,
  19. ) -> Document:
  20. if len(data.content.encode("utf-8")) > CONTENT_MAX_BYTES:
  21. raise ContentTooLargeError(len(data.content.encode("utf-8")))
  22. doc = Document(
  23. title=data.title,
  24. content=data.content,
  25. format=data.format,
  26. session_id=data.session_id,
  27. template_id=data.template_id,
  28. source=data.source,
  29. created_by=user_id,
  30. )
  31. self.db.add(doc)
  32. await self.db.commit()
  33. await self.db.refresh(doc)
  34. return doc
  35. # ------------------------------------------------------------------ #
  36. # READ
  37. # ------------------------------------------------------------------ #
  38. async def get_document(self, document_id: str) -> Document:
  39. result = await self.db.execute(
  40. select(Document).where(Document.id == document_id)
  41. )
  42. doc = result.scalar_one_or_none()
  43. if doc is None:
  44. raise DocumentNotFoundError(document_id)
  45. return doc
  46. async def list_documents(
  47. self,
  48. page: int = 1,
  49. page_size: int = 20,
  50. session_id: str | None = None,
  51. source: str | None = None,
  52. sort_by: str = "updated_at",
  53. sort_order: str = "desc",
  54. ) -> tuple[list[Document], int]:
  55. query = select(Document)
  56. if session_id:
  57. query = query.where(Document.session_id == session_id)
  58. if source:
  59. query = query.where(Document.source == source)
  60. # 排序
  61. sort_col = getattr(Document, sort_by, Document.updated_at)
  62. if sort_order == "asc":
  63. query = query.order_by(sort_col.asc())
  64. else:
  65. query = query.order_by(sort_col.desc())
  66. # 总数
  67. count_result = await self.db.execute(
  68. query.with_only_columns(Document.id)
  69. )
  70. total = len(count_result.all())
  71. # 分页
  72. offset = (page - 1) * page_size
  73. result = await self.db.execute(query.offset(offset).limit(page_size))
  74. docs = list(result.scalars().all())
  75. return docs, total
  76. # ------------------------------------------------------------------ #
  77. # UPDATE
  78. # ------------------------------------------------------------------ #
  79. async def update_document(
  80. self, document_id: str, data: UpdateDocumentRequest
  81. ) -> Document:
  82. doc = await self.get_document(document_id)
  83. if data.title is not None:
  84. doc.title = data.title
  85. if data.content is not None:
  86. if len(data.content.encode("utf-8")) > CONTENT_MAX_BYTES:
  87. raise ContentTooLargeError(len(data.content.encode("utf-8")))
  88. doc.content = data.content
  89. if data.blocks is not None:
  90. doc.content = self._apply_block_updates(doc.content, data.blocks)
  91. doc.updated_at = datetime.now(timezone.utc)
  92. await self.db.commit()
  93. await self.db.refresh(doc)
  94. return doc
  95. # ------------------------------------------------------------------ #
  96. # DELETE
  97. # ------------------------------------------------------------------ #
  98. async def delete_document(self, document_id: str) -> None:
  99. doc = await self.get_document(document_id)
  100. await self.db.delete(doc)
  101. await self.db.commit()
  102. # ------------------------------------------------------------------ #
  103. # 局部块更新:按 level + index 定位并替换对应的标题块
  104. # ------------------------------------------------------------------ #
  105. @staticmethod
  106. def _apply_block_updates(content: str, blocks: list) -> str:
  107. """
  108. 将文档按标题行拆分成若干块,按 level+index 替换对应块后重新拼接。
  109. 块的定义:以标题行(# / ## / ...)为分割点,
  110. 每个标题及其下属正文构成一个块。
  111. """
  112. lines = content.split("\n")
  113. # 找出所有标题行的位置
  114. heading_pattern = re.compile(r"^(#{1,6})\s+")
  115. heading_positions: list[tuple[int, int]] = [] # (line_index, level)
  116. for i, line in enumerate(lines):
  117. m = heading_pattern.match(line)
  118. if m:
  119. level = len(m.group(1))
  120. heading_positions.append((i, level))
  121. # 统计每个 level 已出现的次数,得到 index
  122. level_counter: dict[int, int] = {}
  123. heading_info: list[tuple[int, int, int]] = [] # (line_idx, level, index)
  124. for line_idx, level in heading_positions:
  125. idx = level_counter.get(level, 0)
  126. heading_info.append((line_idx, level, idx))
  127. level_counter[level] = idx + 1
  128. # 构建查找字典 (level, index) → line_idx
  129. heading_map: dict[tuple[int, int], int] = {
  130. (level, index): line_idx
  131. for line_idx, level, index in heading_info
  132. }
  133. # 将 lines 拆成块
  134. # 块边界 = 各标题行的 line_idx
  135. split_points = sorted({line_idx for line_idx, _, _ in heading_info})
  136. split_points.append(len(lines)) # 末尾哨兵
  137. # 前置正文(首个标题之前的内容)
  138. pre_content_end = split_points[0] if split_points else len(lines)
  139. chunks: list[str] = []
  140. chunks.append("\n".join(lines[:pre_content_end]))
  141. # 各标题块
  142. block_keys: list[tuple[int, int] | None] = [None] # 对应 pre_content
  143. for i, sp in enumerate(split_points[:-1]):
  144. end = split_points[i + 1]
  145. chunks.append("\n".join(lines[sp:end]))
  146. _, level, index = heading_info[i]
  147. block_keys.append((level, index))
  148. # 执行替换
  149. for block_update in blocks:
  150. key = (block_update.level, block_update.index)
  151. if key in heading_map:
  152. chunk_idx = block_keys.index(key)
  153. chunks[chunk_idx] = block_update.content
  154. return "\n".join(chunks).strip() + "\n"