From 08affe251b5b04855b817d1f960efc6bc05e1b5f Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Mon, 15 Jun 2026 11:25:02 +0800 Subject: [PATCH] feat: add ImageToVideo and TextToVideo model credential classes with validation and encryption --- .../gemini_model_provider/credential/itv.py | 92 ++++++++++++ .../gemini_model_provider/credential/ttv.py | 98 +++++++++++++ .../gemini_model_provider.py | 134 ++++++++++++------ .../impl/gemini_model_provider/model/ttv.py | 132 +++++++++++++++++ 4 files changed, 412 insertions(+), 44 deletions(-) create mode 100644 apps/models_provider/impl/gemini_model_provider/credential/itv.py create mode 100644 apps/models_provider/impl/gemini_model_provider/credential/ttv.py create mode 100644 apps/models_provider/impl/gemini_model_provider/model/ttv.py diff --git a/apps/models_provider/impl/gemini_model_provider/credential/itv.py b/apps/models_provider/impl/gemini_model_provider/credential/itv.py new file mode 100644 index 00000000000..7105472eb60 --- /dev/null +++ b/apps/models_provider/impl/gemini_model_provider/credential/itv.py @@ -0,0 +1,92 @@ +# coding=utf-8 + +from typing import Dict, Any + +from django.utils.translation import gettext_lazy as _, gettext + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, PasswordInputField +from models_provider.base_model_provider import BaseModelCredential, ValidCode +from common.utils.logger import maxkb_logger + + + +class ImageToVideoModelCredential(BaseForm, BaseModelCredential): + """ + Credential class for the Qwen Image-to-Video model. + Provides validation and encryption for the model credentials. + """ + + base_url = forms.TextInputField(_("Base Url"), required=True, default_value="https://generativelanguage.googleapis.com") + api_key = PasswordInputField("API Key", required=True) + + def is_valid( + self, + model_type: str, + model_name: str, + model_credential: Dict[str, Any], + model_params: Dict[str, Any], + provider, + raise_exception: bool = False, + ) -> bool: + """ + Validate the model credentials. + + :param model_type: Type of the model (e.g., 'TEXT_TO_Video'). + :param model_name: Name of the model. + :param model_credential: Dictionary containing the model credentials. + :param model_params: Parameters for the model. + :param provider: Model provider instance. + :param raise_exception: Whether to raise an exception on validation failure. + :return: Boolean indicating whether the credentials are valid. + """ + model_type_list = provider.get_model_type_list() + if not any(mt.get("value") == model_type for mt in model_type_list): + raise AppApiException( + ValidCode.valid_error.value, + gettext("{model_type} Model type is not supported").format(model_type=model_type), + ) + + required_keys = ["api_key", "base_url"] + for key in required_keys: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext("{key} is required").format(key=key)) + return False + + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.check_auth() + except Exception as e: + maxkb_logger.error(f"Exception: {e}", exc_info=True) + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException( + ValidCode.valid_error.value, + gettext("Verification failed, please check whether the parameters are correct: {error}").format( + error=str(e) + ), + ) + return False + + return True + + def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: + """ + Encrypt sensitive fields in the model dictionary. + + :param model: Dictionary containing model details. + :return: Dictionary with encrypted sensitive fields. + """ + return {**model, "api_key": super().encryption(model.get("api_key", ""))} + + def get_model_params_setting_form(self, model_name: str): + """ + Get the parameter setting form for the specified model. + + :param model_name: Name of the model. + :return: Parameter setting form. + """ + pass diff --git a/apps/models_provider/impl/gemini_model_provider/credential/ttv.py b/apps/models_provider/impl/gemini_model_provider/credential/ttv.py new file mode 100644 index 00000000000..cf0e83b1d46 --- /dev/null +++ b/apps/models_provider/impl/gemini_model_provider/credential/ttv.py @@ -0,0 +1,98 @@ +# coding=utf-8 + +from typing import Dict, Any + +from django.utils.translation import gettext_lazy as _, gettext + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel +from common.forms.switch_field import SwitchField +from models_provider.base_model_provider import BaseModelCredential, ValidCode +from common.utils.logger import maxkb_logger + + + +class TextToVideoModelCredential(BaseForm, BaseModelCredential): + """ + Credential class for the Qwen Text-to-Video model. + Provides validation and encryption for the model credentials. + """ + base_url = forms.TextInputField(_("Base Url"), required=True, default_value="https://generativelanguage.googleapis.com") + api_key = PasswordInputField('API Key', required=True) + + def is_valid( + self, + model_type: str, + model_name: str, + model_credential: Dict[str, Any], + model_params: Dict[str, Any], + provider, + raise_exception: bool = False + ) -> bool: + """ + Validate the model credentials. + + :param model_type: Type of the model (e.g., 'TEXT_TO_Video'). + :param model_name: Name of the model. + :param model_credential: Dictionary containing the model credentials. + :param model_params: Parameters for the model. + :param provider: Model provider instance. + :param raise_exception: Whether to raise an exception on validation failure. + :return: Boolean indicating whether the credentials are valid. + """ + model_type_list = provider.get_model_type_list() + if not any(mt.get('value') == model_type for mt in model_type_list): + raise AppApiException( + ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type) + ) + + required_keys = ['api_key', 'base_url'] + for key in required_keys: + if key not in model_credential: + if raise_exception: + raise AppApiException( + ValidCode.valid_error.value, + gettext('{key} is required').format(key=key) + ) + return False + + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.check_auth() + except Exception as e: + maxkb_logger.error(f'Exception: {e}', exc_info=True) + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException( + ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}' + ).format(error=str(e)) + ) + return False + + return True + + def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: + """ + Encrypt sensitive fields in the model dictionary. + + :param model: Dictionary containing model details. + :return: Dictionary with encrypted sensitive fields. + """ + return { + **model, + 'api_key': super().encryption(model.get('api_key', '')) + } + + def get_model_params_setting_form(self, model_name: str): + """ + Get the parameter setting form for the specified model. + + :param model_name: Name of the model. + :return: Parameter setting form. + """ + pass diff --git a/apps/models_provider/impl/gemini_model_provider/gemini_model_provider.py b/apps/models_provider/impl/gemini_model_provider/gemini_model_provider.py index c8b38fb1b1b..8dfb30d377e 100644 --- a/apps/models_provider/impl/gemini_model_provider/gemini_model_provider.py +++ b/apps/models_provider/impl/gemini_model_provider/gemini_model_provider.py @@ -1,21 +1,29 @@ #!/usr/bin/env python # -*- coding: UTF-8 -*- """ -@Project :MaxKB +@Project :MaxKB @File :gemini_model_provider.py @Author :Brian Yang -@Date :5/13/24 7:47 AM +@Date :5/13/24 7:47 AM """ + import os from common.utils.common import get_file_content -from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ - ModelInfoManage +from models_provider.base_model_provider import ( + IModelProvider, + ModelProvideInfo, + ModelInfo, + ModelTypeConst, + ModelInfoManage, +) from models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential from models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential +from models_provider.impl.gemini_model_provider.credential.itv import ImageToVideoModelCredential from models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential from models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential from models_provider.impl.gemini_model_provider.credential.tti import GeminiTextToImageModelCredential +from models_provider.impl.gemini_model_provider.credential.ttv import TextToVideoModelCredential from models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel from models_provider.impl.gemini_model_provider.model.image import GeminiImage from models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel @@ -24,64 +32,93 @@ from django.utils.translation import gettext as _ from models_provider.impl.gemini_model_provider.model.tti import GeminiTextToImage +from models_provider.impl.gemini_model_provider.model.ttv import GenerationVideoModel gemini_llm_model_credential = GeminiLLMModelCredential() gemini_image_model_credential = GeminiImageModelCredential() gemini_stt_model_credential = GeminiSTTModelCredential() gemini_embedding_model_credential = GeminiEmbeddingCredential() gemini_tti_model_credential = GeminiTextToImageModelCredential() +gemini_itv_model_credential = ImageToVideoModelCredential() +gemini_ttv_model_credential = TextToVideoModelCredential() model_info_list = [ - ModelInfo('gemini-1.0-pro', _('Latest Gemini 1.0 Pro model, updated with Google update'), - ModelTypeConst.LLM, - gemini_llm_model_credential, - GeminiChatModel), - ModelInfo('gemini-1.0-pro-vision', _('Latest Gemini 1.0 Pro Vision model, updated with Google update'), - ModelTypeConst.LLM, - gemini_llm_model_credential, - GeminiChatModel), + ModelInfo( + "gemini-1.0-pro", + _("Latest Gemini 1.0 Pro model, updated with Google update"), + ModelTypeConst.LLM, + gemini_llm_model_credential, + GeminiChatModel, + ), + ModelInfo( + "gemini-1.0-pro-vision", + _("Latest Gemini 1.0 Pro Vision model, updated with Google update"), + ModelTypeConst.LLM, + gemini_llm_model_credential, + GeminiChatModel, + ), ] model_image_info_list = [ - ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'), - ModelTypeConst.IMAGE, - gemini_image_model_credential, - GeminiImage), - ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'), - ModelTypeConst.IMAGE, - gemini_image_model_credential, - GeminiImage), + ModelInfo( + "gemini-1.5-flash", + _("Latest Gemini 1.5 Flash model, updated with Google updates"), + ModelTypeConst.IMAGE, + gemini_image_model_credential, + GeminiImage, + ), + ModelInfo( + "gemini-1.5-pro", + _("Latest Gemini 1.5 Flash model, updated with Google updates"), + ModelTypeConst.IMAGE, + gemini_image_model_credential, + GeminiImage, + ), ] model_stt_info_list = [ - ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'), - ModelTypeConst.STT, - gemini_stt_model_credential, - GeminiSpeechToText), - ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'), - ModelTypeConst.STT, - gemini_stt_model_credential, - GeminiSpeechToText), + ModelInfo( + "gemini-1.5-flash", + _("Latest Gemini 1.5 Flash model, updated with Google updates"), + ModelTypeConst.STT, + gemini_stt_model_credential, + GeminiSpeechToText, + ), + ModelInfo( + "gemini-1.5-pro", + _("Latest Gemini 1.5 Flash model, updated with Google updates"), + ModelTypeConst.STT, + gemini_stt_model_credential, + GeminiSpeechToText, + ), ] model_embedding_info_list = [ - ModelInfo('models/embedding-001', '', - ModelTypeConst.EMBEDDING, - gemini_embedding_model_credential, - GeminiEmbeddingModel), - ModelInfo('models/text-embedding-004', '', - ModelTypeConst.EMBEDDING, - gemini_embedding_model_credential, - GeminiEmbeddingModel), + ModelInfo( + "models/embedding-001", "", ModelTypeConst.EMBEDDING, gemini_embedding_model_credential, GeminiEmbeddingModel + ), + ModelInfo( + "models/text-embedding-004", + "", + ModelTypeConst.EMBEDDING, + gemini_embedding_model_credential, + GeminiEmbeddingModel, + ), ] model_tti_info_list = [ - ModelInfo('gemini-3.1-flash-image-preview', "", - ModelTypeConst.TTI, - gemini_tti_model_credential, - GeminiTextToImage) + ModelInfo("gemini-3.1-flash-image-preview", "", ModelTypeConst.TTI, gemini_tti_model_credential, GeminiTextToImage) +] + +ttv_model_info_list = [ + ModelInfo("veo-3.1-generate-preview", "", ModelTypeConst.TTV, gemini_ttv_model_credential, GenerationVideoModel) ] +itv_model_info_list = [ + ModelInfo("veo-3.1-generate-preview", "", ModelTypeConst.ITV, gemini_itv_model_credential, GenerationVideoModel) +] + + model_info_manage = ( ModelInfoManage.builder() .append_model_info_list(model_info_list) @@ -94,16 +131,25 @@ .append_default_model_info(model_stt_info_list[0]) .append_default_model_info(model_embedding_info_list[0]) .append_default_model_info(model_tti_info_list[0]) + .append_model_info_list(ttv_model_info_list) + .append_default_model_info(ttv_model_info_list[0]) + .append_model_info_list(itv_model_info_list) + .append_default_model_info(itv_model_info_list[0]) .build() ) class GeminiModelProvider(IModelProvider): - def get_model_info_manage(self): return model_info_manage def get_model_provide_info(self): - return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content( - os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'gemini_model_provider', 'icon', - 'gemini_icon_svg'))) + return ModelProvideInfo( + provider="model_gemini_provider", + name="Gemini", + icon=get_file_content( + os.path.join( + PROJECT_DIR, "apps", "models_provider", "impl", "gemini_model_provider", "icon", "gemini_icon_svg" + ) + ), + ) diff --git a/apps/models_provider/impl/gemini_model_provider/model/ttv.py b/apps/models_provider/impl/gemini_model_provider/model/ttv.py new file mode 100644 index 00000000000..f3df1e50baf --- /dev/null +++ b/apps/models_provider/impl/gemini_model_provider/model/ttv.py @@ -0,0 +1,132 @@ +import base64 +import time +from typing import Dict +import requests + +from common.utils.logger import maxkb_logger +from models_provider.base_model_provider import MaxKBBaseModel +from models_provider.base_ttv import BaseGenerationVideo + + +class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo): + base_url: str + api_key: str + model: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get("api_key") + self.base_url = kwargs.get("base_url") + self.model = kwargs.get("model") + self.params = kwargs.get("params") + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {"params": {}} + for key, value in model_kwargs.items(): + if key not in ["model_id", "use_local", "streaming"]: + optional_params["params"][key] = value + return GenerationVideoModel( + model=model_name, + base_url=model_credential.get("base_url", "https://generativelanguage.googleapis.com"), + api_key=model_credential.get("api_key"), + **optional_params, + ) + + def check_auth(self): + return True + + def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, last_frame_url=None, **kwargs): + from google import genai + from google.genai import types + client = genai.Client(api_key=self.api_key, http_options={"base_url": self.base_url}) + + # 1. 动态构建 Config 参数字典 + config_params = {} + if self.params.get("aspect_ratio"): + config_params["aspect_ratio"] = self.params["aspect_ratio"] + if self.params.get("resolution"): + config_params["resolution"] = self.params["resolution"] + + try: + # 2. 初始化核心请求参数(文生视频的基础) + operation_args = { + "model": self.model, + "prompt": prompt, + } + + # 3. 处理首帧(图生视频) + if first_frame_url: + maxkb_logger.info("Processing first frame...") + operation_args["image"] = self._load_image_as_sdk_type(first_frame_url) + + # 4. 处理尾帧(图生视频) + if last_frame_url: + maxkb_logger.info("Processing last frame...") + config_params["last_frame"] = self._load_image_as_sdk_type(last_frame_url) + + # 5. 统一组装视频配置(无论是宽高比还是尾帧,都统一在这里安全实例化) + if config_params: + operation_args["config"] = types.GenerateVideosConfig(**config_params) + + # 6. 发起异步生成任务 + maxkb_logger.info(f"Starting video generation with model: {operation_args['model']}") + operation = client.models.generate_videos(**operation_args) + + # 7. 安全轮询任务状态 + max_retries = 120 + retry_count = 0 + wait_time = 10 + + while not operation.done and retry_count < max_retries: + maxkb_logger.info(f"Waiting for video generation to complete... ({retry_count * wait_time}s)") + time.sleep(wait_time) + operation = client.operations.get(operation) + retry_count += 1 + + if not operation.done: + raise TimeoutError("Video generation timed out after 20 minutes") + + # 8. 异常与结果检查 + if operation.error: + raise Exception(f"Video generation failed from Google Side: {operation.error}") + + if not operation.result or not operation.result.generated_videos: + raise Exception("Google API returned empty result.") + + generated_video_obj = operation.result.generated_videos[0] + video_file_ref = generated_video_obj.video + + # 9. 下载视频字节流 + maxkb_logger.info("Downloading video bytes...") + video_bytes = client.files.download(file=video_file_ref) + + return video_bytes + + except Exception as e: + maxkb_logger.error(f"Video generation error: {str(e)}") + raise + + def _load_image_as_sdk_type(self, image_url: str): + """ + 统一从 URL 或 base64 加载图片并构造为包含 bytes 的 types.Image 对象。 + """ + from google.genai import types + + if image_url.startswith("data:"): + header, encoded = image_url.split(",", 1) + mime_type = header.split(";")[0].split(":")[1] + image_bytes = base64.b64decode(encoded) + else: + response = requests.get(image_url, timeout=15) + response.raise_for_status() + mime_type = response.headers.get("Content-Type", "image/jpeg") + image_bytes = response.content + + # 注意:新 SDK 允许你不显式传 mime_type,但传入会更稳妥 + return types.Image(image_bytes=image_bytes, mime_type=mime_type)