File: `.exp/design-workflow-1-grok-1-inference-and-sampling.md` (53 additions, 6 deletions)
@@ -4,10 +4,12 @@
The "Grok-1 Inference and Sampling" workflow provides the machinery to load the Grok-1 model's 314 billion parameters from a checkpoint, initialize the decoder-only transformer architecture with Mixture-of-Experts (MoE) layers and Grouped Query Attention (GQA), set up distributed sharding across GPUs using JAX meshes and PJIT, tokenize prompts with SentencePiece, and generate text autoregressively. Sampling incorporates temperature-controlled softmax, nucleus (top-p) filtering for diversity control, and top-k logging. The design emphasizes correctness for validation, supporting batched multi-request handling via a generator that manages KV caches per request slot, padding for variable lengths, and efficient decode steps post-prefill.
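The sampling step described above can be sketched in plain Python. This is a generic illustration of temperature-scaled softmax followed by nucleus (top-p) filtering, not the repository's JAX implementation; the function name `sample_top_p` and the example logits are assumptions made for the sketch.

```python
import math
import random


def sample_top_p(logits, temperature, nucleus_p, rng):
    """Sample one token id using temperature scaling and nucleus filtering."""
    # Temperature-scaled softmax, shifted by the max for numerical stability.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Keep the smallest set of most-likely tokens whose mass reaches nucleus_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= nucleus_p:
            break

    # Draw from the kept tokens; random.choices normalizes the weights itself.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


rng = random.Random(0)  # stands in for a per-request rng_seed
token = sample_top_p([2.0, 1.0, 0.1, -1.0], temperature=0.7, nucleus_p=0.9, rng=rng)
```

Lower temperatures sharpen the distribution before filtering, and a smaller `nucleus_p` shrinks the candidate set; with the example values only the two most likely tokens survive the filter.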

-Key inputs: Checkpoint in `./checkpoints/ckpt-0/`, `tokenizer.model`, GPU cluster, prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
+Key inputs: Checkpoint in `./checkpoints/ckpt-0/` (downloaded and symlinked automatically by `just download-weights` from the `Justfile`), `tokenizer.model`, a GPU cluster, and prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
Outputs: Generated text strings.
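As a rough sketch, the `Request` fields listed above map naturally onto a dataclass. The field names and types come from the list; the comments and the example values are illustrative assumptions rather than values taken from the repository.

```python
from dataclasses import dataclass


@dataclass
class Request:
    prompt: str         # text to tokenize and continue
    temperature: float  # softmax temperature used when sampling
    nucleus_p: float    # top-p cumulative-probability cutoff
    rng_seed: int       # seed for this request's random stream
    max_len: int        # length limit at which generation stops


# Hypothetical request; every value here is chosen only for illustration.
example = Request(
    prompt="Write a haiku about sharded tensors.",
    temperature=0.7,
    nucleus_p=0.9,
    rng_seed=42,
    max_len=100,
)
```

Keeping the sampling parameters on each request is what lets a single batched generator serve requests with different temperatures and seeds side by side.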
-Entry points: `run.py` for test run, or `InferenceRunner().run()` generator for streaming requests.
+Entry points: `just test` (runs `run.py`) for a test run, or the `runners.InferenceRunner().run()` generator for streaming requests. Use `nix develop` (defined in `flake.nix`) to set up the dev environment, including dependencies and tools.
+Development and setup files: `Justfile` (tasks for download and test), `flake.nix` (Nix dev shell), `.envrc` (direnv integration), `.env.public` (magnet link), `requirements.txt` (Python deps), `.github/hooks/pre-commit` (ruff pre-commit), `.github/workflows/test.yml` (CI linting).
The workflow orchestrates model loading, compilation of sharded compute functions, prompt processing (prefill KV cache while sampling first token), and iterative single-token generation using cached attention keys/values, until max length or EOS.
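The prefill-then-decode flow can be illustrated with a toy loop. Everything here is a stand-in: `toy_step` replaces the real sharded model call, the list `cache` replaces the KV cache, `max_len` is assumed to count prompt plus generated tokens, and the optional `eos_id` check is an assumption layered on top for illustration.

```python
def toy_step(token, cache):
    """Stand-in for one model step: record the token, return the next one."""
    cache.append(token)  # a real cache stores attention keys/values
    return (sum(cache) * 31 + 7) % 50  # arbitrary deterministic "prediction"


def generate(prompt_tokens, max_len, eos_id=None):
    if not prompt_tokens:
        raise ValueError("prompt must contain at least one token")
    cache = []
    # Prefill: run the whole prompt through the model, filling the cache.
    # The output for the last prompt position is the first generated token.
    next_token = None
    for tok in prompt_tokens:
        next_token = toy_step(tok, cache)

    generated = []
    # Decode: feed back one token per step, reusing the cached history.
    while len(prompt_tokens) + len(generated) < max_len:
        generated.append(next_token)
        if eos_id is not None and next_token == eos_id:
            break
        next_token = toy_step(next_token, cache)
    return generated


out = generate([3, 1, 4], max_len=8)
```

The point of the split is that the prompt is processed once; each later step only handles the newest token while the cache supplies the earlier context.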
@@ -45,25 +47,28 @@ The workflow orchestrates model loading, compilation of sharded compute function
Note over Dev,Ruff: GitHub CI test.yml runs ruff on PRs/push
```
## Sharding and Distributed Execution

- **Mesh Configuration**: `make_mesh(local=(data_replicas, model_par), between_hosts=(data_hosts, model_hosts))` creates a hybrid mesh for SPMD parallelism. E.g., a local 1x8 mesh shards the model across 8 GPUs.
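To make the two-level configuration concrete, the sketch below computes the global mesh shape in plain Python. It assumes the common hybrid-mesh convention in which the per-host shape and the between-hosts shape multiply axis by axis, as `jax.experimental.mesh_utils.create_hybrid_device_mesh` documents. The helper names and the row-major device numbering are hypothetical; real device placement depends on the hardware topology.

```python
from itertools import product


def hybrid_mesh_shape(local, between_hosts):
    """Global (data, model) mesh shape: per-axis product of both configs."""
    return tuple(a * b for a, b in zip(local, between_hosts))


def mesh_coordinates(local, between_hosts):
    """Map each flat device index to a (data, model) mesh coordinate."""
    data, model = hybrid_mesh_shape(local, between_hosts)
    # Row-major numbering is an assumption made only for this illustration.
    return {d * model + m: (d, m) for d, m in product(range(data), range(model))}


shape = hybrid_mesh_shape(local=(1, 8), between_hosts=(1, 1))
coords = mesh_coordinates(local=(1, 8), between_hosts=(1, 1))
```

With `local=(1, 8)` on a single host the data axis has size 1, so all eight devices lie along the model axis and each holds a shard of the parameters.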
@@ -127,6 +167,13 @@ sequenceDiagram
- **Error/Edge Cases**: Assumes sufficient memory/GPUs; handles long contexts by left-truncation/padding. No built-in EOS handling (relies on max_len or app logic). Quantized weights require custom unpickling.
- **Performance Notes**: MoE router/experts use JAX vmap/shard_map (serial per-token, inefficient for prod). Focus on correctness/single-host validation.
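The left-truncation and padding behaviour mentioned under Error/Edge Cases might look like the following in plain Python. The helper name `pad_to_size`, the right-side padding, and the pad id of 0 are assumptions for this sketch.

```python
def pad_to_size(tokens, size, pad_id=0):
    """Return exactly `size` token ids: left-truncate, then right-pad."""
    if len(tokens) > size:
        tokens = tokens[-size:]  # drop the oldest tokens, keep the newest
    return list(tokens) + [pad_id] * (size - len(tokens))


too_long = pad_to_size([1, 2, 3, 4, 5], size=3)
too_short = pad_to_size([1, 2], size=4)
```

Truncating from the left keeps the most recent context, which is usually what matters most for the next-token prediction, while padding gives every request slot the same fixed shape.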
+- Reproducible dev environment: `flake.nix` enables a `nix develop` shell that auto-creates `.venv`, installs requirements via pip, provides tools (just, transmission for torrents, the ruff linter), and sets git `core.hooksPath` to `.github/hooks` for pre-commit ruff checks.
+- Direnv: `.envrc` automatically activates the Nix flake and Python layout, and loads `.env.public`, which contains `GROK_MAGNET_LINK`.
+- Checkpoint download: weights for the ~314B-parameter model via torrent (magnet URI) or Hugging Face to `./checkpoints/`. Place or symlink them as `ckpt-0`. Automated with `just download-weights` (uses transmission-cli, creates the directory and symlink).
+- Testing: `just test` runs `python run.py`; `run.py` can also be executed directly.
+- Quality control: Ruff linting in the local pre-commit hook and in GitHub Actions CI (`.github/workflows/test.yml`) on PRs and pushes to main.
+- Concerns: The large download is untested in the PR due to its size; it requires a stable internet connection and sufficient disk space.
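For readers who want to see what the "create dir and symlink" step amounts to, here is a minimal Python sketch. It is not the `Justfile` recipe, whose contents are not shown here; the function name `link_checkpoint` and its parameters are hypothetical.

```python
from pathlib import Path


def link_checkpoint(downloaded_dir, checkpoints_dir="checkpoints", name="ckpt-0"):
    """Create `checkpoints_dir/name` as a symlink to `downloaded_dir`."""
    target = Path(downloaded_dir).resolve()
    if not target.is_dir():
        raise FileNotFoundError(f"no downloaded weights at {target}")
    root = Path(checkpoints_dir)
    root.mkdir(parents=True, exist_ok=True)
    link = root / name
    if link.is_symlink():
        link.unlink()  # replace a stale link; real directories are left alone
    link.symlink_to(target, target_is_directory=True)
    return link
```

Calling `link_checkpoint("/path/to/downloaded/weights")` from the repository root would leave `./checkpoints/ckpt-0` pointing at the download without copying the data, which matters at this checkpoint size.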
This document captures the high-level design, derived from code analysis.
+- **Grok-1 Inference and Sampling** (Workflow 1): This workflow is impacted by the PR's enhancements to development setup and checkpoint acquisition processes. Evidence from PR description and changed files shows additions of Nix-based reproducible environments (flake.nix, .envrc), task automation for downloading model weights via torrent (Justfile, .env.public, transmission), testing (just test running run.py entry point), and linting enforcement (pre-commit, test.yml). These align with and expand the design doc's "Dependencies & Setup" section mentioning torrent/HF download to checkpoints/ckpt-0/, a key input. Core runtime flows unchanged. [PR #286](https://github.com/xai-org/grok-1/pull/286)
+
+Workflows 2 (Model Loading) and 3 (Forward Pass) are unaffected: their docs contain no setup references, and the PR does not change their core files or logic.
+
+## Workflow 1 Analysis
+### Summary of design changes
+Specific aspects affected: Prerequisites for workflow execution, including environment configuration and model data preparation. The PR adds a structured dev setup layer before the user invokes `run.py`.
+
+Implementation:
+- Deterministic deps via nixpkgs in `flake.nix`, with a shellHook automating venv creation, pip installs, and git hooks setup.
+- Auto-activation via direnv (`.envrc` loading the magnet link from `.env.public`).
+- Tasks in `Justfile`: `download-weights` (torrent download + symlink), `test` (`run.py`).
+- Quality: ruff integration locally and in CI for PR validation.
+
+Benefits: Reproducibility across systems, fewer manual steps for large downloads and setup, and enforced standards that help prevent bugs. Implications: Easier collaboration and onboarding; a potential Nix learning curve; an untested download task (per PR concerns).
+
+The design docs have been updated in `.exp/design-workflow-1-grok-1-inference-and-sampling.md` to reflect these changes, including new/updated diagrams and sections.
+
+### Updated Diagrams Showing Changes
+**Initialization Sequence (with additions)**: New green-indicated steps/participant for setup phase; yellow for modified user-run.py interaction; no red removals.