MLX Commands
MLX is Apple's machine learning framework designed for Apple Silicon. It offers lazy evaluation, unified CPU/GPU operations, a unified memory model, and a familiar NumPy-style API for researchers.
Installation
# Install MLX core
pip install mlx
# Install MLX LM tools (inference, fine-tuning, conversion)
pip install mlx-lm
# Install MLX for vision and audio
pip install mlx-vlm
pip install mlx-whisper
# Verify installation
python -c "import mlx.core as mx; print(mx.default_device())"
LLM Text Generation
# Generate text from a HuggingFace model (auto-downloads)
mlx_lm.generate \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--prompt "Explain transformers in ML:" \
--max-tokens 256
# With sampling parameters
mlx_lm.generate \
--model mlx-community/Mistral-7B-Instruct-v0.3-4bit \
--prompt "Write a haiku about coding:" \
--max-tokens 100 \
--temp 0.7 \
--top-p 0.9
# Chat mode
mlx_lm.chat \
--model mlx-community/Llama-3.1-8B-Instruct-4bit
Python API for Generation
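from mlx_lm import load, generate, stream_generate
from mlx_lm.sample_utils import make_sampler
# Load model and tokenizer
model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
# Generate text; recent mlx-lm releases take sampling settings through a
# sampler object (older releases accepted temp=0.7 directly)
prompt = "Explain gradient descent:"
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=256,
    sampler=make_sampler(temp=0.7),
)
print(response)
# Streaming generation; recent releases yield response objects with a .text
# field (older releases yielded plain strings)
for chunk in stream_generate(model, tokenizer, prompt=prompt, max_tokens=256):
    print(chunk.text, end="", flush=True)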
Model Conversion
# Convert HuggingFace model to MLX format
mlx_lm.convert \
--hf-path meta-llama/Llama-3.1-8B-Instruct \
--mlx-path ./mlx-llama-8b \
--dtype float16
# Convert with 4-bit quantization
mlx_lm.convert \
--hf-path meta-llama/Llama-3.1-8B-Instruct \
--mlx-path ./mlx-llama-8b-4bit \
--quantize \
--q-bits 4 \
--q-group-size 64
# Convert with 8-bit quantization
mlx_lm.convert \
--hf-path mistralai/Mistral-7B-Instruct-v0.3 \
--mlx-path ./mlx-mistral-8bit \
--quantize \
--q-bits 8
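As a rough sanity check on the quantization settings above, the sketch below estimates how large the converted weights might be. It assumes group-wise quantization stores one 16-bit scale and one 16-bit bias per group of weights, and it ignores any layers left unquantized. The function name quantized_size_gb and the 8e9 parameter count are illustrative, not part of mlx-lm.

```python
def quantized_size_gb(num_params, q_bits=4, group_size=64,
                      scale_bits=16, bias_bits=16):
    """Estimate the size in decimal GB of group-wise quantized weights.

    Assumption: each group of `group_size` weights carries one scale and
    one bias, so the per-weight overhead is (scale_bits + bias_bits) / group_size.
    """
    bits_per_weight = q_bits + (scale_bits + bias_bits) / group_size
    return num_params * bits_per_weight / 8 / 1e9


# Hypothetical 8-billion-parameter model
print(quantized_size_gb(8e9))            # --q-bits 4 --q-group-size 64
print(quantized_size_gb(8e9, q_bits=8))  # --q-bits 8, same group size
print(8e9 * 16 / 8 / 1e9)                # float16 baseline for comparison
```

Under these assumptions a 4-bit model costs about 4.5 bits per weight rather than 4, which is why a smaller group size increases the file size slightly.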
Fine-Tuning with LoRA
# Fine-tune with LoRA (recent mlx-lm releases use --num-layers; older ones called it --lora-layers)
mlx_lm.lora \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--data ./training_data \
--train \
--iters 1000 \
--batch-size 4 \
--num-layers 16 \
--learning-rate 1e-5
# Resume training from checkpoint
mlx_lm.lora \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--data ./training_data \
--train \
--resume-adapter-file ./adapters/adapters.safetensors \
--iters 500
# Evaluate after training
mlx_lm.lora \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--data ./training_data \
--adapter-path ./adapters \
--test
Fusing LoRA Adapters
# Merge LoRA adapters into base model
mlx_lm.fuse \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--adapter-path ./adapters \
--save-path ./fused-model
# Fuse and de-quantize (writes de-quantized weights rather than 4-bit ones)
mlx_lm.fuse \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--adapter-path ./adapters \
--save-path ./fused-model-dequantized \
--de-quantize
Training Data Format
{"text": "Below is an instruction.\n\n### Instruction:\nExplain gravity.\n\n### Response:\nGravity is a fundamental force..."}
{"text": "Below is an instruction.\n\n### Instruction:\nWhat is DNA?\n\n### Response:\nDNA is a molecule..."}
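Save the data as train.jsonl, valid.jsonl, and test.jsonl in the directory passed to --data, with one JSON object per line. Training reads train.jsonl and valid.jsonl; test.jsonl is only needed when running with --test.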
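One way to produce those three files is sketched below, using only the standard library. The prompt template mirrors the two sample records above; the helper name write_splits, the 10% validation and test fractions, and the fixed seed are illustrative assumptions rather than anything mlx-lm requires.

```python
import json
import random
from pathlib import Path

# Template copied from the sample records; adjust it to match your data
PROMPT_TEMPLATE = (
    "Below is an instruction.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Response:\n{response}"
)


def write_splits(pairs, data_dir, valid_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle (instruction, response) pairs and write train/valid/test JSONL."""
    records = [
        {"text": PROMPT_TEMPLATE.format(instruction=i, response=r)}
        for i, r in pairs
    ]
    random.Random(seed).shuffle(records)
    # Keep at least one example in each held-out split
    n_valid = max(1, int(len(records) * valid_frac))
    n_test = max(1, int(len(records) * test_frac))
    splits = {
        "valid": records[:n_valid],
        "test": records[n_valid:n_valid + n_test],
        "train": records[n_valid + n_test:],
    }
    out = Path(data_dir)
    out.mkdir(parents=True, exist_ok=True)
    for name, rows in splits.items():
        with open(out / f"{name}.jsonl", "w", encoding="utf-8") as f:
            for row in rows:
                # One JSON object per line
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return {name: len(rows) for name, rows in splits.items()}
```

Calling write_splits(pairs, "./training_data") would then create a directory that can be passed to --data.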
MLX Core API
import mlx.core as mx
# Array creation (like NumPy)
a = mx.array([1.0, 2.0, 3.0])
b = mx.zeros((3, 4))
c = mx.ones((2, 3), dtype=mx.float16)
d = mx.random.normal((4, 4))
# Device placement (unified memory - no explicit transfers)
x = mx.array([1.0, 2.0, 3.0]) # Available on both CPU and GPU
# Basic operations
result = mx.matmul(a.reshape(1, -1), mx.ones((3, 4)))
y = mx.exp(x) + mx.sin(x)
# Lazy evaluation - computations only run when needed
z = mx.add(x, x)
z = mx.multiply(z, 2.0)
mx.eval(z) # Triggers computation
# Automatic differentiation
def loss_fn(x):
    return mx.sum(x ** 2)
grad_fn = mx.grad(loss_fn)
grads = grad_fn(mx.array([1.0, 2.0, 3.0]))
MLX Neural Network Module
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# Define a model
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
model = MLP(784, 256, 10)
# Optimizer
optimizer = optim.Adam(learning_rate=1e-3)
# Training step with value_and_grad
def loss_fn(model, x, y):
    logits = model(x)
    return mx.mean(nn.losses.cross_entropy(logits, y))
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
# Training loop (dataloader is a placeholder for any iterable of
# (inputs, labels) batches)
for batch_x, batch_y in dataloader:
    loss, grads = loss_and_grad_fn(model, batch_x, batch_y)
    optimizer.update(model, grads)
    # Force evaluation of the lazily computed parameter and optimizer updates
    mx.eval(model.parameters(), optimizer.state)
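The training loop above iterates over a dataloader that is never defined. A minimal sketch of one possible stand-in follows, written in plain Python so it runs without MLX. The name iterate_batches and its arguments are hypothetical; the batches it yields are ordinary lists that would still need wrapping in mx.array before being passed to the model.

```python
import random


def iterate_batches(features, labels, batch_size, shuffle=True, seed=0):
    """Yield (batch_x, batch_y) lists holding at most batch_size examples."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    order = list(range(len(features)))
    if shuffle:
        # Seeded shuffle so a run can be reproduced
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [features[i] for i in idx], [labels[i] for i in idx]


# Hypothetical usage with the loop above (requires mlx, X and Y are your data):
# dataloader = (
#     (mx.array(bx), mx.array(by)) for bx, by in iterate_batches(X, Y, 32)
# )
```

The last batch may be smaller than batch_size; drop it if fixed shapes matter, for example when the training step is compiled.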
Server Mode
# Start OpenAI-compatible server
mlx_lm.server \
--model mlx-community/Llama-3.1-8B-Instruct-4bit \
--port 8080
# Query the server
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "default",
"messages": [{"role": "user", "content": "Hello!"}],
"max_tokens": 100
}'
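The same request can be issued from Python with only the standard library. The sketch below mirrors the curl call above and assumes the server answers in the usual OpenAI chat-completions shape (choices[0].message.content). The helper names build_chat_request and chat are illustrative.

```python
import json
import urllib.request


def build_chat_request(prompt, base_url="http://localhost:8080", max_tokens=100):
    """Build (but do not send) a POST request for /v1/chat/completions."""
    payload = {
        "model": "default",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt, **kwargs):
    """Send the request and return the assistant's reply text.

    Assumes an OpenAI-style response body; requires the server to be running.
    """
    with urllib.request.urlopen(build_chat_request(prompt, **kwargs)) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


# Example call, once mlx_lm.server is listening on port 8080:
# print(chat("Hello!"))
```

Keeping request construction separate from sending makes the payload easy to inspect or unit-test without a live server.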
Performance Tips
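from functools import partial
import mlx.core as mx
from mlx_lm import load
# Use mx.compile for repeated operations; an already defined model is
# captured from the enclosing scope and its state declared as implicit input
@partial(mx.compile, inputs=model.state)
def fast_forward(x):
    return model(x)
# Use float16 for faster inference: cast after loading ("model-path" is a
# placeholder), or convert the model with --dtype float16 beforehand
model, tokenizer = load("model-path")
model.set_dtype(mx.float16)
# Batch processing (batch_of_inputs is a placeholder)
inputs = mx.array(batch_of_inputs)
outputs = model(inputs) # Processes entire batch at once
mx.eval(outputs)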
MLX vs PyTorch Comparison
| Feature | MLX | PyTorch |
|---|---|---|
| Memory model | Unified (shared CPU/GPU) | Explicit transfers |
| Evaluation | Lazy (deferred) | Eager (immediate) |
| Platform | Primarily Apple Silicon | Cross-platform |
| Array API | NumPy-like | NumPy-like |
| Auto-diff | mx.grad | torch.autograd |
| Compilation | mx.compile | torch.compile |
Common Commands
| Task | Command |
|---|---|
| Generate text | mlx_lm.generate --model MODEL --prompt TEXT |
| Interactive chat | mlx_lm.chat --model MODEL |
| Convert model | mlx_lm.convert --hf-path HF_MODEL --mlx-path OUTPUT |
| Quantize model | mlx_lm.convert --hf-path MODEL --quantize --q-bits 4 |
| LoRA fine-tune | mlx_lm.lora --model MODEL --data DIR --train |
| Fuse adapters | mlx_lm.fuse --model MODEL --adapter-path DIR |
| Start server | mlx_lm.server --model MODEL --port 8080 |