서론
최근 AI, RAG, AI agent와 관련한 내용이 자주 들리기도 하고, 회사 신입들에게 내줄 프로젝트 주제가 Rag를 활용한 챗봇 이나 AI Agent 구축이라는 주제일 것이라고 들어 Spring으로 구축을 해보았습니다. 초보자가 Rag에 대해 알아가는 과정임으로 틀린 부분이 있을 수도 있습니다. 틀린 부분은 댓글을 달아 주시면 고쳐서 재업로드를 하겠습니다!
RAG(Retrieval-Augmented Generation)란 무엇인가?
RAG는 "검색을 통한 보강 생성"의 약자로, 대형 언어 모델(LLM)이 자체 지식만으로 답변을 생성하는 대신, 외부 지식원(예: 문서, 데이터베이스)에서 관련 정보를 검색하여 그 정보를 바탕으로 더 정확하고 구체적인 답변을 생성하는 방식입니다.
주요 특징:
- 검색 단계: 사용자의 질문에 대해 관련 문서나 데이터를 외부 소스에서 검색합니다.
- 생성 단계: 검색된 정보를 언어 모델에 제공하여 답변을 생성합니다.
이를 통해 RAG는 최신 정보나 특정 주제에 대한 깊이 있는 지식을 필요로 하는 질문에도 효과적으로 대응할 수 있습니다.
이론적인 내용은 앞과 같습니다. 제가 이해한 바로는 쉽게 이해하자면 "프로젝트에 맞춰서 답변을 할 수 있다"로 이해했습니다. 만약 개발 회사라면 각 회사마다 정해진 코드 컨벤션이 있을 것입니다. 그 코드 컨벤션을 vector DB에 기록해놓고 내용을 LLM에 함께 가지고 나서 규칙에 어긋나는 부분을 수정할 수 있게 사용할 수 있음을 의미합니다. 이는 프로젝트마다 각각의 고유한 정보를 활용할 수 있다는 장점이 있습니다. 최근 LLM을 활용법에 관심이 깊어지면서 각 회사마다 챗봇들을 구축하고 있는데 각 고유한 정보들을 vector DB에 저장해서 활용하는 것 같습니다.
그러면 어떤 흐름으로 이루어지는가?
사용자 질의문 -> 사용자 질의문 임베딩 -> vector DB에 있는 임베딩 된 값들 유사도 조회 -> 유사도가 높은 값들과 함께 LLM에 질문과 함께 질의 (vector DB로 부터 온 유사도 높은 문장과 사용자 질의문을 LLM에 보냄) -> 사용자에게 답변 반환
순으로 이어지는 것을 확인할 수 있었습니다.
저는 100번의 강의 시청보다 1번의 실습이 의미있다고 생각하여 직접 구현해보기로 결정했습니다.
Google을 찾아보니 Python으로 구축한 것은 많았으나, Java로 Rag를 구축한 것은 많지 않았습니다. 나중에 Python으로 바꾸더라도 Java/Spring으로 우선 구현을 한번 시도는 해보자해서 시작했습니다.
처음에 vector DB를 결정해야했습니다.
많은 vector DB 중에 어떤 것을 선택해야할까? 고민을 했습니다.
저의 기준은
1. Open Source 여야한다. (무료여야 한다.)
2. 러닝 커브가 낮아야 한다.
3. Vector DB로 탄생한 DB여야 한다.
정도가 있었습니다. 그 후보로 chroma와 Qdrant를 뽑을 수 있었습니다.
처음에는 chroma를 사용하려고 했으나, Python과 node.js 정도만 지원하고 Java/Spring과 연계하려면 살짝 복잡해보이는 것을 느꼈습니다.
그래서 Qdrant를 local에 설치하고 실행했습니다!
Qdrant 설치
docker pull qdrant/qdrant
docker run -p 6333:6333 -p 6334:6334 qdrant/qdrant
그리고 실행하고 http://localhost:6333/dashboard 에 들어가니 다음과 같은 화면을 만날 수 있었습니다.
아마 깔자마자 위와 같은 화면은 아닐 것이고, 저는 excel 데이터 값을 넣느라 이미 collection이 존재해서 다음과 나타났습니다. 화면과 달라도 걱정하지 않아도 됩니다.
저는 gradle 로 구축을 했기에 dependencies는 다음과 같습니다
build.gradle
repositories {
mavenCentral()
maven {url 'https://repo.spring.io/snapshot'} //스냅샷 저장소 추가
}
dependencies {
implementation 'org.springframework.boot:spring-boot-starter-data-jpa'
implementation 'org.springframework.boot:spring-boot-starter-web'
compileOnly 'org.projectlombok:lombok'
runtimeOnly 'com.h2database:h2'
annotationProcessor 'org.projectlombok:lombok'
testImplementation 'org.springframework.boot:spring-boot-starter-test'
testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:1.0.0-SNAPSHOT'
// Apache POI for Excel
implementation 'org.apache.poi:poi-ooxml:5.2.3'
// Jackson (JSON 파싱용)
implementation 'com.fasterxml.jackson.core:jackson-databind'
}
YML 설정
spring:
ai:
openai:
api-key: <Open AI Key>
qdrant:
base-url: "http://localhost:6333"
collection-name: "excel_collection"
vector-size: 1536
LLM 모델을 사용하다 보니 Open AI Key가 필요합니다. 또, 이를 사용하기 위해선 최소 5달러를 결제해야 합니다 .. ㅠㅠ
Open API 사이트는 다음과 같습니다.
https://platform.openai.com/docs/overview
LLMConfig
@Configuration
public class LLMConfig {
@Value("${spring.ai.openai.api-key}")
private String openaiApiKey;
@Bean
public OpenAiApi openAiApi() {
return new OpenAiApi(openaiApiKey);
}
@Bean
public OpenAiChatOptions openAiChatOptions() {
return OpenAiChatOptions.builder()
.model("gpt-4o-mini")
.temperature(0.0)
.maxCompletionTokens(1024)
.topP(1.0)
.frequencyPenalty(0.0)
.presencePenalty(0.0)
.build();
}
@Bean
public OpenAiChatModel chatModel(OpenAiApi openAiApi, OpenAiChatOptions openAiChatOptions) {
return new OpenAiChatModel(openAiApi, openAiChatOptions);
}
}
LLMConfig는 다음과 같습니다.
특히
@Bean
public OpenAiChatOptions openAiChatOptions() {
return OpenAiChatOptions.builder()
.model("gpt-4o-mini")
.temperature(0.0)
.maxCompletionTokens(1024)
.topP(1.0)
.frequencyPenalty(0.0)
.presencePenalty(0.0)
.build();
}
이 쪽 부분을 중점적으로 봐야합니다.
각 옵션의 의미는 다음과 같습니다.
- temperture: 생성 시 다양성을 조절하는 매개변수 (0~1)
- Rag를 사용할 때는 temperature을 최대한 낮춰야 한다고 합니다. 가장 가능성이 높은 것을 선택할 수 있지만, 그 다음 것을 선택한다던지 다양성을 높이는 변수입니다. 정확성을 높게 하고 싶으면 0에 맞추는 것을 추천합니다.
- 값이 높을수록 창의적이며 낮을수록 결정론적입니다.
- maxCompletionTokens: 생성된 텍스트의 최대 토큰 수
- 최대 output이 16K 정도 되는데 약 16000 토큰 정도를 허가하게 됩니다. 그 최대 토큰을 조절해주는 것입니다.
- 1K 즉 1024 정도로 맞추는게 적당하다고 합니다.
- topP: 다음 단어 선택 시 고려되는 확률 분포의 크기를 지정하는 매개변수 (0.1~1)
- 예) 0.8 -> 상위 80%의 확률을 가진 token 중 다음 단어를 선택
- 높을수록 반응이 다양하며 낮을수록 정확하고 사실적입니다.
- frequencyPenaliy: 단어 반복에 대해서 패널티를 주어서 반복하지 않게 하는 매개변수
- presence_penalty: 새 토큰에 대한 출현빈도를 낮추기 위한 매개변수
저는 위와 같이 설정하고 진행했습니다.
다음부터는 코드를 쭉 보여드리겠습니다.
OpenAiEmbeddingService
@Service
@Transactional
public class OpenAiEmbeddingService {
@Value("${spring.ai.openai.api-key}")
private String openAiApiKey;
private final RestTemplate restTemplate;
private final ObjectMapper objectMapper;
public OpenAiEmbeddingService(RestTemplate restTemplate, ObjectMapper objectMapper) {
this.restTemplate = restTemplate;
this.objectMapper = objectMapper;
}
public List<Double> getEmbedding(String text) {
String url = "https://api.openai.com/v1/embeddings";
// 요청 바디
/*
{
"input": text,
"model": "text-embedding-ada-002"
}
*/
try {
// JSON 구조 만들기 (Jackson)
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("model", "text-embedding-ada-002");
requestMap.put("input", text); // text 안에 어떤 특수문자가 있든 Jackson이 알아서 이스케이프
String requestBody = objectMapper.writeValueAsString(requestMap);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setBearerAuth(openAiApiKey);
HttpEntity<String> entity = new HttpEntity<>(requestBody, headers);
ResponseEntity<String> response = restTemplate.exchange(
url, HttpMethod.POST, entity, String.class
);
if (response.getStatusCode().is2xxSuccessful()) {
JsonNode root = objectMapper.readTree(response.getBody());
JsonNode embeddingArray = root.get("data").get(0).get("embedding");
List<Double> vector = new ArrayList<>();
for (JsonNode val : embeddingArray) {
vector.add(val.asDouble());
}
return vector;
} else {
System.err.println("OpenAI Embedding API error: " + response.getStatusCode());
return null;
}
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
}
QdrantService
@Service
@Transactional
public class QdrantService {
@Value("${qdrant.base-url}")
private String qdrantUrl;
@Value("${qdrant.collection-name}")
private String collectionName;
@Value("${qdrant.vector-size}")
private int vectorSize;
private final RestTemplate restTemplate;
private final ObjectMapper objectMapper;
public QdrantService(RestTemplate restTemplate, ObjectMapper objectMapper) {
this.restTemplate = restTemplate;
this.objectMapper = objectMapper;
}
/**
* 컬렉션 생성
*/
public void createCollectionIfNotExists() {
try {
// GET /collections 로 확인 or 그냥 PUT 시도
String url = qdrantUrl + "/collections/" + collectionName;
// {
// "vectors": {
// "size": 1536,
// "distance": "Cosine"
// }
// }
Map<String, Object> vectorsMap = new HashMap<>();
vectorsMap.put("size", vectorSize);
vectorsMap.put("distance", "Cosine");
Map<String, Object> body = new HashMap<>();
body.put("vectors", vectorsMap);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(body, headers);
restTemplate.put(url, entity); // PUT 요청
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* Qdrant 업서트
* points: (id, vector, payload)
*/
public void upsertPoints(List<Map<String, Object>> points) {
try {
String url = qdrantUrl + "/collections/" + collectionName + "/points?wait=true";
// body = {
// "points": [
// { "id": 1, "vector": [...], "payload": {...} },
// ...
// ]
// }
Map<String, Object> body = new HashMap<>();
body.put("points", points);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(body, headers);
restTemplate.put(url, entity, String.class);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 검색
* @param queryVector 검색할 벡터
* @param limit 상위 개수
* @return 검색 결과 JSON
*/
public String search(List<Double> queryVector, int limit) {
try {
String url = qdrantUrl + "/collections/" + collectionName + "/points/search";
Map<String, Object> body = new HashMap<>();
body.put("vector", queryVector);
body.put("limit", limit);
body.put("with_payload", true);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<Map<String, Object>> entity = new HttpEntity<>(body, headers);
ResponseEntity<String> resp = restTemplate.postForEntity(url, entity, String.class);
return resp.getBody(); // JSON 스트링
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
}
RagService
여기서 급하게 작성한 코드들이라
qdrantService.createCollectionIfNotExists();
이 부분은 최초 실행시에만 두고 그다음엔 주석처리 해야합니다. 이미 컬렉션이 만들어 졌다는 에러가 나타납니다!!
@Service
@Transactional
public class RagService {
private final ExcelService excelService;
private final OpenAiEmbeddingService embeddingService;
private final QdrantService qdrantService;
public RagService(ExcelService excelService,
OpenAiEmbeddingService embeddingService,
QdrantService qdrantService) {
this.excelService = excelService;
this.embeddingService = embeddingService;
this.qdrantService = qdrantService;
}
public void uploadExcel(InputStream excelStream) {
// 1) 엑셀 "모든 열" 파싱 -> List<Map<String, String>>
List<Map<String, String>> rows = excelService.parseExcelAllColumns(excelStream);
// 2) Qdrant 컬렉션 생성
// qdrantService.createCollectionIfNotExists();
// 3) 각 행 -> 임베딩 -> Qdrant 포인트 목록
List<Map<String, Object>> points = new ArrayList<>();
AtomicInteger idCounter = new AtomicInteger(1);
for (Map<String, String> rowMap : rows) {
// (a) 여러 열의 값을 합쳐 하나의 텍스트로 만듦 (쉼표, 공백 등 구분)
// 예: "Column_0=값0 Column_1=값1 Column_2=값2 ..."
// 혹은 값들만 단순히 이어붙일 수도 있습니다.
String combinedText = combineRowToString(rowMap);
// (b) 임베딩
List<Double> vector = embeddingService.getEmbedding(combinedText);
if (vector == null) {
continue;
}
int id = idCounter.getAndIncrement();
// (c) Qdrant 포인트 구조: {"id", "vector", "payload"}
Map<String, Object> point = new HashMap<>();
point.put("id", id);
point.put("vector", vector);
// (d) payload에 "original_text" 또는 "combinedText" 등 저장
// rowMap 전체를 payload에 넣고 싶다면 payload.putAll(rowMap)로도 가능
Map<String, Object> payload = new HashMap<>();
payload.put("combined_text", combinedText); // 임베딩에 사용한 전체 텍스트
payload.put("excel_data", rowMap); // 원본 전체 열 데이터
point.put("payload", payload);
points.add(point);
}
// 4) 업서트
qdrantService.upsertPoints(points);
}
/**
* 예시: 여러 열의 데이터를 하나의 문자열로 합쳐 주는 헬퍼 메서드
*/
private String combineRowToString(Map<String, String> rowMap) {
// 방법 1) 열 이름 & 값 함께 표시
// Column_0=값0, Column_1=값1, ...
// return rowMap.entrySet().stream()
// .map(e -> e.getKey() + "=" + e.getValue())
// .collect(Collectors.joining(" "));
// 방법 2) 값들만 공백/쉼표로 연결
// 값0 값1 값2 ...
return String.join(" ", rowMap.values());
}
/**
* 질문(쿼리)에 대해 임베딩 -> Qdrant 검색 -> 결과 JSON 반환
*/
public String askQuestion(String question, int topK) {
// 1) 질문 임베딩
List<Double> queryVec = embeddingService.getEmbedding(question);
if (queryVec == null) {
return "Embedding failed or null";
}
// 2) Qdrant 검색
String resultJson = qdrantService.search(queryVec, topK);
return resultJson;
}
}
ExcelService
@Service
@Transactional
public class ExcelService {
// 만약 여러 열을 읽고 싶다면, 예시로 이런 식 메서드도 가능
public List<Map<String, String>> parseExcelAllColumns(InputStream excelInputStream) {
List<Map<String, String>> rows = new ArrayList<>();
try (Workbook workbook = new XSSFWorkbook(excelInputStream)) {
Sheet sheet = workbook.getSheetAt(0);
for (Row row : sheet) {
if (row == null) continue;
// 각 셀(0 ~ row.getLastCellNum()-1) 반복
Map<String, String> rowMap = new HashMap<>();
for (int c = 0; c < row.getLastCellNum(); c++) {
Cell cell = row.getCell(c);
if (cell != null) {
// 셀 타입에 맞게 변환
String value = cell.toString(); // 간단 변환
// 열 이름을 키로 쓸 수도, 'Column_c'처럼 쓸 수도
rowMap.put("Column_" + c, value);
}
}
rows.add(rowMap);
}
} catch (Exception e) {
e.printStackTrace();
}
return rows;
}
}
RagController
@RestController
@RequestMapping("/api/rag")
public class RagController {
private final RagService ragService;
public RagController(RagService ragService) {
this.ragService = ragService;
}
// 1) 엑셀 업로드 -> Qdrant 업서트
@PostMapping("/uploadExcel")
public String uploadExcel(@RequestParam("file") MultipartFile file) {
try {
ragService.uploadExcel(file.getInputStream());
return "Excel data is uploaded & embedded successfully!";
} catch (IOException e) {
return "Error: " + e.getMessage();
}
}
// 2) 질문 -> 검색 (단순 JSON 결과)
@GetMapping("/ask")
public String askQuestion(@RequestParam String question, @RequestParam(defaultValue = "3") int topK) {
return ragService.askQuestion(question, topK);
}
}
이렇게 하고 uploadExcel에 Post 요청을 보내서 엑셀을 업로드하게되면
이런식으로 값이 들어간 것을 확인할 수 있습니다!!
여기서 Open Graph를 누르면
다음과 같이 시각화도 할 수 있고
Find Similar를 누르면 유사한 것도 함께 조회할 수 있습니다.
잘들어 갔는지 확인해본 결과
다음과 같이 정확한 답변을 주고 있습니다!!
이상으로 Java/Spring으로 Rag 맛보기를 해봤는데 너무 설정할 것이 많다고 느꼈습니다. Java/Spring으로는 몇시간씩 구현을 했습니다... 그래서 python으로도 동일한 기능을 똑같이 구현해보았습니다.
python으로는 LangChain이라는 좋은 프레임워크를 제공하고 있어서 사용해봤습니다.
그 결과 30분정도면 다 구현할 수 있었습니다...... 제가 Java, Spring 개발자라서 Java로 구축하면 더 좋겠다 싶었는데 이쪽 분야는 확실히 Python이 강세인 것 같습니다.
그래도 Java/Spring으로도 구현은 가능하다고 느꼈습니다. 더 좋은 라이브러리가 등장한다면 충분히 할 수 있을 것 같습니다. (아직까진 python이라 저는 python으로 개발을 더 진행하였습니다.. ㅎㅎ..)
※ 다시한번 말하지만 연습삼아서 한 부분이 많고, 생소한 분야라 이해가 깊지 않습니다. 이런식으로하면 체험정도는 할 수 있구나~ 라는 글입니다.
'Project' 카테고리의 다른 글
[Project] TDD 개발 시 given의 코드 양 줄이기 (Fixture, Persister) + @spy 활용 (1) | 2024.06.08 |
---|---|
[Project] 실시간 채팅 구현 시 FCM Token 발송 여부 결정하기 (0) | 2024.03.25 |
[Project] Spring + Stomp 테스트 하는 과정.. (실시간 채팅 구현) (6) | 2024.02.27 |
[Project] 멀티 모듈을 왜 쓸까? (멀티 모듈 도입 전) (1) | 2024.01.15 |
[Project] google Oauth2 로그인 시 refreshToken을 받는 방법 (최초 로그인에서 저장을 못했을 때) (0) | 2023.11.29 |