# Source: synthetic-data-es / pipelines / 08_rag_grounded_embeddings.py
# Uploaded by: pauvanbr
# Commit message: Add Pipeline 08: RAG-Grounded Embeddings (grounded triplets from real docs)
# Commit: e6b3c59 (verified)
#!/usr/bin/env python3
"""
Pipeline 8: RAG-Grounded Embeddings — triplets grounded in real documents
=========================================================================
Method: an improvement over E5-Mistral (arxiv:2401.00368) and Gecko (arxiv:2403.20327)
- E5 generates purely synthetic triplets (query, positive, negative) with an LLM
- Problem: the "positive_document" is hallucinated by the LLM, so it may not be factual
- Solution: use REAL documents as the positive and generate the query and hard negative conditioned on them
Advantages:
- The positive document is a real chunk, so the retrieval target contains no hallucinated information
- The query is generated conditioned on the document, so it is naturally relevant
- The hard negative is generated with instructions to be topically similar without answering the query
Uses distilabel's GenerateSentencePair (action="query", triplet=True):
- Input: "anchor" (real document)
- Output: "positive" (a query the document answers), "negative" (hard negative)
Pipeline:
1. LoadDataFromHub — loads Spanish documents
2. ChunkForEmbeddings — splits them into chunks of ~300 whitespace-separated words (short chunks suit embedding training)
3. GenerateSentencePair — generates a query + hard negative for each chunk
4. FormatForSentenceTransformers — renames columns to the sentence-transformers layout
5. KeepColumns — final output
Output: dataset with columns (query, positive, negative, instruction, task_type, language, source_title).
Compatible with sentence-transformers MultipleNegativesRankingLoss.
Supported sources:
- wikimedia/wikipedia (config 20231101.es)
- allenai/c4 (config es)
- Any Hugging Face dataset with a text column
Usage:
python 08_rag_grounded_embeddings.py \\
--model_id "meta-llama/Meta-Llama-3.1-70B-Instruct" \\
--source_repo "wikimedia/wikipedia" \\
--source_config "20231101.es" \\
--num_docs 10000 \\
--output_repo "tu-org/rag-embedding-triplets-es"
# Multi-action (query + semantic in parallel):
python 08_rag_grounded_embeddings.py \\
--multi_action \\
--output_repo "tu-org/rag-embedding-triplets-es"
"""
import argparse
import os
from typing import Generator, Optional
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.steps.generators.huggingface import LoadDataFromHub
from distilabel.steps import KeepColumns, StepInput, step
# ---------------------------------------------------------------------------
# Global configuration (the chunk size is overwritten by the pipeline builders)
# ---------------------------------------------------------------------------
# Chunk size budget. Despite the name, ChunkForEmbeddings compares it against
# whitespace-split word counts, not tokenizer tokens.
_EMBED_CHUNK_MAX_TOKENS = 300
# Overlap setting read by ChunkForEmbeddings; none of the visible builders
# overrides it.
_EMBED_CHUNK_OVERLAP = 50
# ---------------------------------------------------------------------------
# Custom steps
# ---------------------------------------------------------------------------
def _split_long_paragraph(paragraph, max_words, overlap_words):
    """Split one paragraph into word windows of at most ``max_words`` words.

    A paragraph that already fits is returned untouched (its internal line
    breaks are preserved). Otherwise consecutive windows share
    ``overlap_words`` words. ``overlap_words`` must be smaller than
    ``max_words`` so that the window always advances.
    """
    words = paragraph.split()
    if len(words) <= max_words:
        return [paragraph]
    stride = max_words - overlap_words
    pieces = []
    start = 0
    while True:
        pieces.append(" ".join(words[start:start + max_words]))
        if start + max_words >= len(words):
            break
        start += stride
    return pieces


def _chunk_text_for_embeddings(text, max_words, overlap_words, min_words=30):
    """Pack the paragraphs of ``text`` into chunks of at most ``max_words`` words.

    Paragraphs are delimited by blank lines. Oversized paragraphs are first cut
    into overlapping word windows; the resulting units are then packed greedily.
    When a chunk is closed, its last unit is repeated at the start of the next
    chunk only if it fits inside the overlap budget and still leaves room for
    the incoming unit. Chunks with fewer than ``min_words`` words are dropped.
    """
    max_words = max(1, max_words)
    # Cap the overlap at half the budget so the sliding window always advances.
    overlap_words = min(max(0, overlap_words), max_words // 2)

    units = []
    for raw_paragraph in text.split("\n\n"):
        paragraph = raw_paragraph.strip()
        if paragraph:
            units.extend(_split_long_paragraph(paragraph, max_words, overlap_words))

    chunks = []
    current = []
    current_words = 0
    for unit in units:
        unit_words = len(unit.split())
        if current and current_words + unit_words > max_words:
            chunks.append("\n\n".join(current))
            tail = current[-1]
            tail_words = len(tail.split())
            # Carry context forward without ever exceeding the size budget.
            if tail_words <= overlap_words and tail_words + unit_words <= max_words:
                current, current_words = [tail], tail_words
            else:
                current, current_words = [], 0
        current.append(unit)
        current_words += unit_words
    if current:
        chunks.append("\n\n".join(current))
    return [chunk for chunk in chunks if len(chunk.split()) >= min_words]


@step(inputs=["text"], outputs=["anchor", "title", "url", "chunk_id"])
def ChunkForEmbeddings(inputs: StepInput) -> Generator:
    """Split documents into short chunks suitable for embedding training.

    Sizes are measured in whitespace-separated words (not tokenizer tokens).
    The budget comes from ``_EMBED_CHUNK_MAX_TOKENS`` and the overlap from
    ``_EMBED_CHUNK_OVERLAP``, both read when the step processes a batch.

    Each input row must provide ``text``; ``title`` and ``url`` are copied
    through when present. Documents and chunks shorter than 30 words are
    skipped. Every emitted row has the columns ``anchor`` (the chunk text,
    which is the input column GenerateSentencePair consumes), ``title``,
    ``url`` and ``chunk_id``.

    Yields:
        One list with all chunk rows produced from the input batch.
    """
    max_chunk_words = _EMBED_CHUNK_MAX_TOKENS
    overlap_words = _EMBED_CHUNK_OVERLAP
    outputs = []
    for row in inputs:
        # ``or ""`` also covers columns that exist but hold None.
        text = row.get("text") or ""
        title = row.get("title") or ""
        url = row.get("url") or ""
        if len(text.split()) < 30:
            continue
        # Fall back to the URL so untitled sources still get distinct ids.
        doc_key = title or url
        chunk_texts = _chunk_text_for_embeddings(text, max_chunk_words, overlap_words)
        for chunk_id, chunk_text in enumerate(chunk_texts):
            outputs.append({
                "anchor": chunk_text,
                "title": title,
                "url": url,
                "chunk_id": f"{doc_key}_{chunk_id}",
            })
    yield outputs
@step(
    inputs=["anchor", "positive", "negative", "title"],
    outputs=["query", "positive", "negative", "instruction", "task_type", "language", "source_title"],
)
def FormatForSentenceTransformers(inputs: StepInput) -> Generator:
    """Convert GenerateSentencePair rows into sentence-transformers triplets.

    The incoming rows carry:
    - anchor: the original document chunk
    - positive: the text generated by the LLM (the query)
    - negative: the generated hard negative

    The emitted rows swap the roles so that:
    - query: the generated text
    - positive: the real document chunk (the incoming anchor)
    - negative: the hard negative

    Rows missing any of the three texts are dropped: a triplet without a
    negative (or with an empty one) is not usable for triplet training.

    NOTE: ``instruction``, ``task_type`` and ``language`` are constants here;
    they do not depend on which action produced the pair.

    Yields:
        One list with the formatted rows of the input batch.
    """
    outputs = []
    for row in inputs:
        # ``or ""`` normalises None values left behind by failed generations.
        document = row.get("anchor") or ""
        generated_query = row.get("positive") or ""  # query written by the LLM
        hard_negative = row.get("negative") or ""
        if not (document and generated_query.strip() and hard_negative.strip()):
            continue
        # Swap: query=generated text, positive=real document.
        outputs.append({
            "query": generated_query,
            "positive": document,
            "negative": hard_negative,
            "instruction": "Retrieve a relevant Spanish document that answers the query",
            "task_type": "short-long",
            "language": "spanish",
            "source_title": row.get("title") or "",
        })
    yield outputs
# ---------------------------------------------------------------------------
# Pipeline builders
# ---------------------------------------------------------------------------
def build_pipeline(
    model_id: str,
    source_repo: str,
    source_config: Optional[str],
    source_column: str,
    num_docs: int,
    output_repo: str,
    batch_size: int = 10,
    max_chunk_tokens: int = 300,
    temperature: float = 0.7,
    max_new_tokens: int = 512,
    private: bool = True,
    action: str = "query",
    context: str = "",
) -> Pipeline:
    """Build the single-action pipeline that produces grounded embedding triplets.

    Args:
        model_id: Model used by the inference endpoint (also used as tokenizer id).
        source_repo: Hub dataset repository to load documents from.
        source_config: Optional dataset configuration name.
        source_column: Name of the dataset column that holds the document text.
        num_docs: Number of examples to load from the ``train`` split.
        output_repo: Not used here; publishing is done by the caller.
        batch_size: Batch size for the loader and the generation task.
        max_chunk_tokens: Chunk size budget, stored in the module-level
            ``_EMBED_CHUNK_MAX_TOKENS`` that the chunking step reads.
        temperature: Sampling temperature for generation.
        max_new_tokens: Generation length limit.
        private: Not used here; publishing is done by the caller.
        action: ``action`` forwarded to GenerateSentencePair.
        context: ``context`` forwarded to GenerateSentencePair; a default
            Spanish context is used when empty.

    Returns:
        The assembled (not yet executed) pipeline.
    """
    # NOTE(review): the chunk size travels through a module global; confirm it
    # is visible to the chunking step if steps run in separate processes.
    global _EMBED_CHUNK_MAX_TOKENS
    _EMBED_CHUNK_MAX_TOKENS = max_chunk_tokens

    if not context:
        context = (
            "Los textos generados deben estar relacionados con documentos en español "
            "sobre temas variados: ciencia, historia, tecnología, cultura, sociedad, "
            "economía, derecho, medicina y educación del mundo hispanohablante."
        )

    with Pipeline(
        name="rag-grounded-embeddings-es",
        description=(
            "Generación de triplets de embedding fundamentados en documentos reales en español. "
            "Los positives son documentos reales (sin alucinación)."
        ),
    ) as pipeline:
        llm = InferenceEndpointsLLM(
            model_id=model_id,
            tokenizer_id=model_id,
            generation_kwargs={
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "do_sample": True,
            },
        )

        load_kwargs = {
            "name": "load_documents",
            "repo_id": source_repo,
            "split": "train",
            "num_examples": num_docs,
            "batch_size": batch_size,
        }
        if source_config:
            load_kwargs["config"] = source_config
        loader = LoadDataFromHub(**load_kwargs)

        # The chunker reads its text from "text"; map it onto the requested
        # column so datasets with a differently named text column work too.
        chunker_kwargs = {"name": "chunk_for_embeddings"}
        if source_column and source_column != "text":
            chunker_kwargs["input_mappings"] = {"text": source_column}
        chunker = ChunkForEmbeddings(**chunker_kwargs)

        gen_pairs = GenerateSentencePair(
            name="generate_pairs",
            llm=llm,
            action=action,
            triplet=True,
            context=context,
            num_generations=1,
            input_batch_size=batch_size,
        )
        formatter = FormatForSentenceTransformers(name="format_st")
        keep = KeepColumns(
            name="keep_columns",
            columns=["query", "positive", "negative", "instruction", "task_type", "language", "source_title"],
        )

        loader >> chunker >> gen_pairs >> formatter >> keep

    return pipeline
def build_multi_action_pipeline(
    model_id: str,
    source_repo: str,
    source_config: Optional[str],
    source_column: str,
    num_docs: int,
    output_repo: str,
    batch_size: int = 10,
    max_chunk_tokens: int = 300,
    temperature: float = 0.7,
    max_new_tokens: int = 512,
    private: bool = True,
) -> Pipeline:
    """Build a pipeline that generates triplets for two task types in parallel.

    Every chunk is sent to two GenerateSentencePair branches:
    1. ``query``: search query -> document (asymmetric, short-long)
    2. ``semantically-similar``: a rephrasing of the chunk (symmetric)

    NOTE: both branches use the same formatting step, so their rows carry the
    same ``instruction`` and ``task_type`` values.

    Args:
        model_id: Model used by the inference endpoint (also used as tokenizer id).
        source_repo: Hub dataset repository to load documents from.
        source_config: Optional dataset configuration name.
        source_column: Name of the dataset column that holds the document text.
        num_docs: Number of examples to load from the ``train`` split.
        output_repo: Not used here; publishing is done by the caller.
        batch_size: Batch size for the loader and the generation tasks.
        max_chunk_tokens: Chunk size budget, stored in the module-level
            ``_EMBED_CHUNK_MAX_TOKENS`` that the chunking step reads.
        temperature: Sampling temperature for generation.
        max_new_tokens: Generation length limit.
        private: Not used here; publishing is done by the caller.

    Returns:
        The assembled (not yet executed) pipeline with two leaf steps.
    """
    global _EMBED_CHUNK_MAX_TOKENS
    _EMBED_CHUNK_MAX_TOKENS = max_chunk_tokens

    contexts = {
        "query": (
            "Genera una consulta de búsqueda corta y natural en español que un usuario "
            "escribiría en Google para encontrar este documento."
        ),
        "semantically-similar": (
            "Genera una frase en español que exprese la misma información de forma diferente, "
            "como una paráfrasis o reformulación del contenido."
        ),
    }
    output_columns = ["query", "positive", "negative", "instruction", "task_type", "language", "source_title"]

    with Pipeline(
        name="rag-grounded-embeddings-multi-es",
        description="Triplets multi-acción para embedding training diverso",
    ) as pipeline:
        llm = InferenceEndpointsLLM(
            model_id=model_id,
            tokenizer_id=model_id,
            generation_kwargs={
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
                "do_sample": True,
            },
        )

        load_kwargs = {
            "name": "load_documents",
            "repo_id": source_repo,
            "split": "train",
            "num_examples": num_docs,
            "batch_size": batch_size,
        }
        if source_config:
            load_kwargs["config"] = source_config
        loader = LoadDataFromHub(**load_kwargs)

        # The chunker reads its text from "text"; map it onto the requested
        # column so datasets with a differently named text column work too.
        chunker_kwargs = {"name": "chunk_for_embeddings"}
        if source_column and source_column != "text":
            chunker_kwargs["input_mappings"] = {"text": source_column}
        chunker = ChunkForEmbeddings(**chunker_kwargs)

        gen_query = GenerateSentencePair(
            name="gen_query_pairs",
            llm=llm,
            action="query",
            triplet=True,
            context=contexts["query"],
            num_generations=1,
            input_batch_size=batch_size,
        )
        gen_semantic = GenerateSentencePair(
            name="gen_semantic_pairs",
            llm=llm,
            action="semantically-similar",
            triplet=True,
            context=contexts["semantically-similar"],
            num_generations=1,
            input_batch_size=batch_size,
        )

        fmt_query = FormatForSentenceTransformers(name="format_query")
        fmt_semantic = FormatForSentenceTransformers(name="format_semantic")

        keep_query = KeepColumns(name="keep_query", columns=list(output_columns))
        keep_semantic = KeepColumns(name="keep_semantic", columns=list(output_columns))

        # Fan-out: chunks feed both generators in parallel.
        loader >> chunker >> [gen_query, gen_semantic]
        gen_query >> fmt_query >> keep_query
        gen_semantic >> fmt_semantic >> keep_semantic

    return pipeline
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI arguments, run the selected pipeline and publish the results.

    The generated distiset is pushed to ``--output_repo``. Afterwards the
    ``train`` splits of all its subsets are concatenated and pushed to
    ``<output_repo>-st-format`` on a best-effort basis (a failure there is
    reported but does not abort the script). The Hub token is read from the
    ``HF_TOKEN`` environment variable.
    """
    parser = argparse.ArgumentParser(
        description="RAG-Grounded Embeddings: Triplets fundamentados en documentos en español"
    )
    parser.add_argument("--model_id", type=str, default="meta-llama/Meta-Llama-3.1-70B-Instruct")
    parser.add_argument("--source_repo", type=str, default="wikimedia/wikipedia")
    parser.add_argument("--source_config", type=str, default="20231101.es")
    parser.add_argument("--source_column", type=str, default="text")
    parser.add_argument("--num_docs", type=int, default=10000)
    parser.add_argument("--output_repo", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--max_chunk_tokens", type=int, default=300)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--private", action="store_true", default=True)
    # "--private" defaults to True and can only set True, so on its own it can
    # never be switched off; "--public" is the explicit opt-out.
    parser.add_argument("--public", dest="private", action="store_false",
                        help="Publicar los datasets como públicos (por defecto son privados)")
    parser.add_argument("--action", type=str, default="query",
                        choices=["query", "semantically-similar", "answer"])
    parser.add_argument("--multi_action", action="store_true", default=False,
                        help="Usar pipeline multi-acción (query + semantic en paralelo)")
    parser.add_argument("--context", type=str, default="",
                        help="Contexto adicional para el generador de pares")
    args = parser.parse_args()

    # "none" (any casing) or an empty value means "no dataset configuration".
    raw_config = (args.source_config or "").strip()
    config = raw_config if raw_config and raw_config.lower() != "none" else None

    shared_kwargs = dict(
        model_id=args.model_id,
        source_repo=args.source_repo,
        source_config=config,
        source_column=args.source_column,
        num_docs=args.num_docs,
        output_repo=args.output_repo,
        batch_size=args.batch_size,
        max_chunk_tokens=args.max_chunk_tokens,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        private=args.private,
    )
    if args.multi_action:
        pipeline = build_multi_action_pipeline(**shared_kwargs)
    else:
        pipeline = build_pipeline(**shared_kwargs, action=args.action, context=args.context)

    hf_token = os.environ.get("HF_TOKEN")
    distiset = pipeline.run(use_cache=True)
    distiset.push_to_hub(
        args.output_repo,
        include_script=True,
        private=args.private,
        token=hf_token,
    )
    print(f"\n✅ Dataset generado y subido a: https://huggingface.co/datasets/{args.output_repo}")

    # Build a combined version for sentence-transformers (best effort).
    try:
        from datasets import concatenate_datasets

        all_ds = []
        for key in distiset:
            if "train" in distiset[key]:
                all_ds.append(distiset[key]["train"])
        if all_ds:
            combined = concatenate_datasets(all_ds) if len(all_ds) > 1 else all_ds[0]
            combined.push_to_hub(
                f"{args.output_repo}-st-format",
                private=args.private,
                token=hf_token,
            )
            print(f"✅ ST-format: https://huggingface.co/datasets/{args.output_repo}-st-format")
    except Exception as e:
        # Deliberately broad: the main dataset is already published, so a
        # failure in this optional step is reported instead of raised.
        print(f"⚠️ No se pudo crear formato ST: {e}")
if __name__ == "__main__":
    # Run the CLI only when the file is executed as a script, not on import.
    main()