refactor using openai

robinrolle
2025-04-13 01:13:20 +02:00
parent a440aee411
commit 537527fba0
3 changed files with 7 additions and 6 deletions

View File

@@ -7,7 +7,7 @@ from dto.requests import GameStartRequestDTO, GameDecisionRequestDTO
 from services.extractor import extract_profile, extract_passport, extract_description, extract_account
 from services.julius_baer_api_client import JuliusBaerApiClient
 from utils.storage.game_files_manager import store_game_round_data
-from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai.chat_models import ChatOpenAI
 from validation.llm_validate import AdvisorDecision
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import PydanticOutputParser
@@ -129,7 +129,7 @@ class Advisor:
         )
         # 4. LLM chain
-        chain = prompt | ChatGoogleGenerativeAI(model="gemini-2.0-flash") | parser
+        chain = prompt | ChatOpenAI(model="gpt-4o-mini") | parser
         # 5. Invocation
         result: AdvisorDecision = chain.invoke({
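For orientation, below is a minimal, self-contained sketch of the prompt | ChatOpenAI | parser chain this hunk rewires. The AdvisorDecision fields and prompt text are illustrative placeholders, not the schema actually defined in validation/llm_validate.py. Note that ChatOpenAI lives in the langchain-openai package and reads OPENAI_API_KEY from the environment by default, whereas the removed ChatGoogleGenerativeAI relied on GOOGLE_API_KEY from langchain-google-genai.

# Hedged sketch only: the AdvisorDecision fields and the prompt below are
# placeholders, not the repository's actual schema or prompt.
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai.chat_models import ChatOpenAI
from pydantic import BaseModel

class AdvisorDecision(BaseModel):
    accept: bool
    reason: str

parser = PydanticOutputParser(pydantic_object=AdvisorDecision)
prompt = ChatPromptTemplate.from_template(
    "Decide whether to accept this client.\n"
    "{format_instructions}\n"
    "Client data:\n{client_data}"
).partial(format_instructions=parser.get_format_instructions())

# Requires the langchain-openai package and an OPENAI_API_KEY environment variable.
chain = prompt | ChatOpenAI(model="gpt-4o-mini", temperature=0) | parser
result: AdvisorDecision = chain.invoke({"client_data": "free-form client text"})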

View File

@@ -4,7 +4,7 @@ from typing import Callable, Type, Any, TypeVar
 from langchain_core.runnables import Runnable
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import PydanticOutputParser
-from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai.chat_models import ChatOpenAI
 from pydantic import BaseModel
 from utils.parsers import process_profile, process_passport, process_account
@@ -73,7 +73,7 @@ def extract_passport(client_data: dict[str, Any]) -> FromPassport:
 def extract_profile(client_data: dict[str, Any]) -> FromProfile:
-    passport_data = client_data.get("profile")
+    profile_data = client_data.get("profile")
     prompt_template = (
         "Extract the following information from the provided text.\n"
@ -83,7 +83,7 @@ def extract_profile(client_data: dict[str, Any]) -> FromProfile:
)
result = __run_extraction_chain(
raw_file_data=passport_data,
raw_file_data=profile_data,
file_processor=process_profile,
pydantic_model=FromProfile,
prompt_template=prompt_template,
@@ -125,7 +125,7 @@ def __run_extraction_chain(
     prompt = ChatPromptTemplate.from_template(prompt_template)
-    chain: Runnable = prompt | ChatGoogleGenerativeAI(model=model_name) | parser
+    chain: Runnable = prompt | ChatOpenAI(model="gpt-4o-mini") | parser
     result = chain.invoke({
         "processed_text": processed_text,