diff --git a/api/api/db/logic/process_schema.py b/api/api/db/logic/process_schema.py index 1bc6f51..c953eb8 100644 --- a/api/api/db/logic/process_schema.py +++ b/api/api/db/logic/process_schema.py @@ -114,6 +114,7 @@ async def get_process_schema_page_DTO( limit=limit, ) + async def get_process_schema_by_id(connection: AsyncConnection, id: int) -> Optional[ProcessSchema]: """ Получает process_schema по id. diff --git a/api/api/db/logic/ps_node.py b/api/api/db/logic/ps_node.py index 2aafa75..1b38b72 100644 --- a/api/api/db/logic/ps_node.py +++ b/api/api/db/logic/ps_node.py @@ -1,15 +1,15 @@ -from typing import Optional +from typing import Optional, List from datetime import datetime, timezone -from sqlalchemy import insert, select, desc +from sqlalchemy import insert, select, desc, and_, or_, delete, update from sqlalchemy.ext.asyncio import AsyncConnection -from orm.tables.process import ps_node_table +from orm.tables.process import ps_node_table, node_link_table, process_schema_table from api.schemas.process.ps_node import Ps_Node from model_nodes.node_listen_models import ListenNodeCoreSchema -from orm.tables.process import NodeStatus, NodeType +from orm.tables.process import NodeStatus async def get_ps_node_by_id(connection: AsyncConnection, id: int) -> Optional[Ps_Node]: @@ -27,21 +27,6 @@ async def get_ps_node_by_id(connection: AsyncConnection, id: int) -> Optional[Ps return Ps_Node.model_validate(ps_node_data) -async def get_ps_node_by_type_and_ps_id(connection: AsyncConnection, node_type: str, ps_id: int) -> Optional[Ps_Node]: - """ - Получает ps_node по node_type и ps_id. - """ - query = select(ps_node_table).where(ps_node_table.c.node_type == node_type, ps_node_table.c.ps_id == ps_id) - - ps_node_db_cursor = await connection.execute(query) - - ps_node_data = ps_node_db_cursor.mappings().one_or_none() - if not ps_node_data: - return None - - return Ps_Node.model_validate(ps_node_data) - - async def get_last_ps_node_by_creator_and_ps_id( connection: AsyncConnection, creator_id: int, ps_id: int ) -> Optional[Ps_Node]: @@ -86,3 +71,152 @@ async def create_ps_node_schema( await connection.commit() return await get_last_ps_node_by_creator_and_ps_id(connection, creator_id, validated_schema.ps_id) + + +async def check_node_connection(connection: AsyncConnection, node_id: int, next_node_id: int, port: int) -> bool: + """ + Проверяет, подключен ли next_node_id к node_id через указанный порт. + """ + query = select(node_link_table).where( + and_( + node_link_table.c.node_id == node_id, + node_link_table.c.next_node_id == next_node_id, + node_link_table.c.link_point_id == port, + ) + ) + + result = await connection.execute(query) + return result.mappings().first() is not None + + +async def get_all_child_nodes_with_depth(connection: AsyncConnection, node_id: int) -> List[tuple[Ps_Node, int]]: + """ + Рекурсивно находит ВСЕ дочерние узлы с их уровнем вложенности. + """ + all_child_nodes = [] + visited_nodes = set() + + async def find_children_with_depth(current_node_id: int, current_depth: int): + if current_node_id in visited_nodes: + return + + visited_nodes.add(current_node_id) + + query = ( + select(ps_node_table) + .join(node_link_table, ps_node_table.c.id == node_link_table.c.next_node_id) + .where(node_link_table.c.node_id == current_node_id) + ) + + result = await connection.execute(query) + child_nodes = result.mappings().all() + + for node_data in child_nodes: + node = Ps_Node.model_validate(node_data) + all_child_nodes.append((node, current_depth + 1)) + await find_children_with_depth(node.id, current_depth + 1) + + await find_children_with_depth(node_id, 0) + return all_child_nodes + + +async def get_nodes_for_deletion_ordered(connection: AsyncConnection, node_id: int) -> List[int]: + """ + Возвращает список ID узлов для удаления в правильном порядке: + от самых последних к первым. + """ + child_nodes_with_depth = await get_all_child_nodes_with_depth(connection, node_id) + + child_nodes_with_depth.sort(key=lambda x: x[1], reverse=True) + + ordered_node_ids = [node.id for node, depth in child_nodes_with_depth] + + ordered_node_ids.append(node_id) + + return ordered_node_ids + + +async def delete_ps_node_by_id_completely(connection: AsyncConnection, node_id: int) -> tuple[bool, str]: + """ + Полностью удаляет узел из базы данных по ID. + """ + try: + node_query = select(ps_node_table).where(ps_node_table.c.id == node_id) + node_result = await connection.execute(node_query) + node_data = node_result.mappings().first() + + if not node_data: + return False, "Node not found" + + ps_id = node_data["ps_id"] + + await connection.execute( + delete(node_link_table).where( + or_(node_link_table.c.node_id == node_id, node_link_table.c.next_node_id == node_id) + ) + ) + + await remove_node_from_process_schema_settings(connection, ps_id, node_id) + + result = await connection.execute(delete(ps_node_table).where(ps_node_table.c.id == node_id)) + + if result.rowcount > 0: + await connection.commit() + return True, "Success" + else: + await connection.rollback() + return False, "Node not found" + + except Exception as e: + await connection.rollback() + return False, str(e) + + +async def delete_ps_nodes_sequentially_with_error_handling( + connection: AsyncConnection, node_ids: List[int] +) -> List[int]: + """ + Поочередно удаляет узлы из базы данных. + Возвращает список успешно удаленных ID узлов. + Выбрасывает исключение при первой ошибке. + """ + successfully_deleted = [] + + for node_id in node_ids: + success, error_message = await delete_ps_node_by_id_completely(connection, node_id) + if success: + successfully_deleted.append(node_id) + else: + raise Exception(f"Failed to delete node {node_id}: {error_message}") + + return successfully_deleted + + +async def remove_node_from_process_schema_settings(connection: AsyncConnection, ps_id: int, node_id: int): + """ + Удаляет ноду из поля settings в таблице process_schema. + """ + from api.db.logic.process_schema import get_process_schema_by_id + + process_schema = await get_process_schema_by_id(connection, ps_id) + + if not process_schema or not process_schema.settings: + return + + settings = process_schema.settings + + if "nodes" in settings and isinstance(settings["nodes"], list): + settings["nodes"] = [ + node_item + for node_item in settings["nodes"] + if not ( + isinstance(node_item, dict) + and "node" in node_item + and isinstance(node_item["node"], dict) + and node_item["node"].get("id") == node_id + ) + ] + + await connection.execute( + update(process_schema_table).where(process_schema_table.c.id == ps_id).values(settings=settings) + ) diff --git a/api/api/endpoints/process_schema.py b/api/api/endpoints/process_schema.py index dffbbba..923b0ba 100644 --- a/api/api/endpoints/process_schema.py +++ b/api/api/endpoints/process_schema.py @@ -146,8 +146,6 @@ async def create_processschema_endpoint( validated_start_schema = start_node.validate() - print(validated_start_schema) - db_start_schema = await create_ps_node_schema(connection, validated_start_schema, user_validation.id) node = ProcessSchemaSettingsNode( diff --git a/api/api/endpoints/ps_node.py b/api/api/endpoints/ps_node.py index f4c750a..787497e 100644 --- a/api/api/endpoints/ps_node.py +++ b/api/api/endpoints/ps_node.py @@ -7,20 +7,30 @@ from api.db.logic.account import get_user_by_login from api.schemas.base import bearer_schema from api.schemas.process.process_schema import ProcessSchemaSettingsNodeLink, ProcessSchemaSettingsNode -from api.schemas.process.ps_node import Ps_NodeFrontResponseNode, Ps_NodeRequest +from api.schemas.process.ps_node import Ps_NodeFrontResponseNode, Ps_NodeRequest, Ps_NodeDeleteRequest from api.schemas.process.ps_node import Ps_NodeFrontResponse from api.services.auth import get_current_user -from api.db.logic.ps_node import create_ps_node_schema +from api.db.logic.ps_node import ( + create_ps_node_schema, + get_ps_node_by_id, + check_node_connection, + get_nodes_for_deletion_ordered, + delete_ps_nodes_sequentially_with_error_handling, +) from api.db.logic.node_link import get_last_link_name_by_node_id, create_node_link_schema -from api.db.logic.process_schema import update_process_schema_settings_by_id +from api.db.logic.process_schema import update_process_schema_settings_by_id, get_process_schema_by_id +from api.services.user_role_validation import ( + db_user_role_validation_for_list_events_and_process_schema_by_list_event_id, +) from core import VorkNodeRegistry, VorkNodeLink from model_nodes import VorkNodeLinkData from api.utils.to_camel_dict import to_camel_dict +from api.error import create_operation_error, create_access_error, create_validation_error, create_server_error api_router = APIRouter( @@ -29,6 +39,78 @@ api_router = APIRouter( ) +@api_router.delete("", dependencies=[Depends(bearer_schema)], status_code=status.HTTP_200_OK) +async def delete_ps_node_endpoint( + ps_node_delete_data: Ps_NodeDeleteRequest, + connection: AsyncConnection = Depends(get_connection_dep), + current_user=Depends(get_current_user), +): + process_schema = await get_process_schema_by_id(connection, ps_node_delete_data.schema_id) + if process_schema is None: + raise create_operation_error( + message="Process schema not found", + status_code=status.HTTP_404_NOT_FOUND, + details={"schema_id": ps_node_delete_data.schema_id}, + ) + + try: + await db_user_role_validation_for_list_events_and_process_schema_by_list_event_id( + connection, current_user, process_schema.creator_id + ) + except Exception as e: + raise create_access_error( + message="Access denied", + status_code=status.HTTP_403_FORBIDDEN, + details={"user_id": current_user, "schema_creator_id": process_schema.creator_id, "reason": str(e)}, + ) + + ps_node = await get_ps_node_by_id(connection, ps_node_delete_data.node_id) + if ps_node is None: + raise create_operation_error( + message="PS node not found", + status_code=status.HTTP_404_NOT_FOUND, + details={"node_id": ps_node_delete_data.node_id}, + ) + + next_ps_node = await get_ps_node_by_id(connection, ps_node_delete_data.next_node_id) + if next_ps_node is None: + raise create_operation_error( + message="Next PS node not found", + status_code=status.HTTP_400_BAD_REQUEST, + details={"next_node_id": ps_node_delete_data.next_node_id}, + ) + + is_connected = await check_node_connection( + connection, ps_node_delete_data.node_id, ps_node_delete_data.next_node_id, int(ps_node_delete_data.port) + ) + + if not is_connected: + raise create_validation_error( + message="Node connection validation failed", + status_code=status.HTTP_400_BAD_REQUEST, + details={ + "node_id": ps_node_delete_data.node_id, + "next_node_id": ps_node_delete_data.next_node_id, + "port": ps_node_delete_data.port, + }, + ) + + ordered_node_ids = await get_nodes_for_deletion_ordered(connection, ps_node_delete_data.next_node_id) + + try: + deleted_node_ids = await delete_ps_nodes_sequentially_with_error_handling(connection, ordered_node_ids) + except Exception as e: + raise create_server_error( + message="Failed to delete nodes", + status_code=500, + details={"error": str(e), "ordered_node_ids": ordered_node_ids}, + ) + + return { + "deleted_node_ids": deleted_node_ids, + } + + @api_router.post("", dependencies=[Depends(bearer_schema)], response_model=Ps_NodeFrontResponse) async def create_ps_node_endpoint( ps_node: Ps_NodeRequest, @@ -37,6 +119,10 @@ async def create_ps_node_endpoint( ): user_validation = await get_user_by_login(connection, current_user) + process_schema = await get_process_schema_by_id(connection, ps_node.data["ps_id"]) + if process_schema is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Process schema not found") + registery = VorkNodeRegistry() vork_node = registery.get(ps_node.data["node_type"]) diff --git a/api/api/error/__init__.py b/api/api/error/__init__.py new file mode 100644 index 0000000..c19a55f --- /dev/null +++ b/api/api/error/__init__.py @@ -0,0 +1,26 @@ +""" +Модуль для обработки ошибок API. +""" + +from .error_model.error_types import ServerError, AccessError, OperationError, ValidationError, ErrorType + +from .error_handlers import ( + handle_api_error, + create_server_error, + create_access_error, + create_operation_error, + create_validation_error, +) + +__all__ = [ + "ServerError", + "AccessError", + "OperationError", + "ValidationError", + "ErrorType", + "handle_api_error", + "create_server_error", + "create_access_error", + "create_operation_error", + "create_validation_error", +] diff --git a/api/api/error/error_handlers.py b/api/api/error/error_handlers.py new file mode 100644 index 0000000..a298559 --- /dev/null +++ b/api/api/error/error_handlers.py @@ -0,0 +1,54 @@ +""" +Обработчики ошибок для API. +""" + +from typing import Optional, Dict, Any +from fastapi import HTTPException + +from .error_model.error_types import ServerError, AccessError, OperationError, ValidationError, ErrorType + + +def handle_api_error( + error_type: ErrorType, message: str, status_code: int, details: Optional[Dict[str, Any]] = None +) -> HTTPException: + """ + Функция для создания HTTPException с правильной структурой ошибки. + + """ + match error_type: + case ErrorType.SERVER: + error = ServerError(message=message, details=details) + case ErrorType.ACCESS: + error = AccessError(message=message, details=details) + case ErrorType.OPERATION: + error = OperationError(message=message, details=details) + case ErrorType.VALIDATION: + error = ValidationError(message=message, details=details) + case _: + error = ServerError(message=message, details=details) + + return HTTPException(status_code=status_code, detail=error.model_dump(mode="json")) + + +def create_server_error( + message: str, status_code: int = 500, details: Optional[Dict[str, Any]] = None +) -> HTTPException: + return handle_api_error(error_type=ErrorType.SERVER, message=message, status_code=status_code, details=details) + + +def create_access_error( + message: str, status_code: int = 403, details: Optional[Dict[str, Any]] = None +) -> HTTPException: + return handle_api_error(error_type=ErrorType.ACCESS, message=message, status_code=status_code, details=details) + + +def create_operation_error( + message: str, status_code: int = 400, details: Optional[Dict[str, Any]] = None +) -> HTTPException: + return handle_api_error(error_type=ErrorType.OPERATION, message=message, status_code=status_code, details=details) + + +def create_validation_error( + message: str, status_code: int = 400, details: Optional[Dict[str, Any]] = None +) -> HTTPException: + return handle_api_error(error_type=ErrorType.VALIDATION, message=message, status_code=status_code, details=details) diff --git a/api/api/error/error_model/__init__.py b/api/api/error/error_model/__init__.py new file mode 100644 index 0000000..b9070d1 --- /dev/null +++ b/api/api/error/error_model/__init__.py @@ -0,0 +1,7 @@ +""" +Модели ошибок для API. +""" + +from .error_types import ServerError, AccessError, OperationError, ValidationError, ErrorType + +__all__ = ["ServerError", "AccessError", "OperationError", "ValidationError", "ErrorType"] diff --git a/api/api/error/error_model/error_types.py b/api/api/error/error_model/error_types.py new file mode 100644 index 0000000..92c9bb2 --- /dev/null +++ b/api/api/error/error_model/error_types.py @@ -0,0 +1,56 @@ +""" +Типизированные модели ошибок для API. +""" + +from enum import Enum +from typing import Optional, Dict, Any +from pydantic import BaseModel + + +class ErrorType(str, Enum): + """ + Типы ошибок API. + """ + + SERVER = "SERVER" + ACCESS = "ACCESS" + OPERATION = "OPERATION" + VALIDATION = "VALIDATION" + + +class BaseError(BaseModel): + """ + Базовая модель ошибки. + """ + error_type: ErrorType + message: str + details: Optional[Dict[str, Any]] = None + + +class ServerError(BaseError): + """ + Критические серверные ошибки (БД, соединения и прочие неприятности). + """ + error_type: ErrorType = ErrorType.SERVER + + +class AccessError(BaseError): + """ + Ошибки доступа (несоответствие тенантности, ролям доступа). + """ + error_type: ErrorType = ErrorType.ACCESS + + +class OperationError(BaseError): + """ + Ошибки операции (несоответствие прохождению верификации, ошибки в датасете). + """ + error_type: ErrorType = ErrorType.OPERATION + + +class ValidationError(BaseError): + """ + Ошибки валидации (несоответствие первичной валидации). + """ + error_type: ErrorType = ErrorType.VALIDATION + field_errors: Optional[Dict[str, str]] = None diff --git a/api/api/schemas/process/ps_node.py b/api/api/schemas/process/ps_node.py index 5d75a1a..49d7bb3 100644 --- a/api/api/schemas/process/ps_node.py +++ b/api/api/schemas/process/ps_node.py @@ -6,6 +6,13 @@ from orm.tables.process import NodeStatus, NodeType from api.schemas.base import Base +class Ps_NodeDeleteRequest(Base): + schema_id: int + node_id: int + port: str + next_node_id: int + + class Ps_NodeRequest(Base): data: Dict[str, Any] links: Dict[str, Any]