Skip to content

Commit f62fe14

Browse files
Update design for PR xai-org#286: Reproducible and automatically configure development environments
1 parent 4f45f05 commit f62fe14

File tree

2 files changed

+133
-6
lines changed

2 files changed

+133
-6
lines changed

.exp/design-workflow-1-grok-1-inference-and-sampling.md

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
The \"Grok-1 Inference and Sampling\" workflow provides the machinery to load the Grok-1 model's 314 billion parameters from a checkpoint, initialize the decoder-only transformer architecture with Mixture-of-Experts (MoE) layers and Grouped Query Attention (GQA), set up distributed sharding across GPUs using JAX meshes and PJIT, tokenize prompts with SentencePiece, and generate text autoregressively. Sampling incorporates temperature-controlled softmax, nucleus (top-p) filtering for diversity control, and top-k logging. The design emphasizes correctness for validation, supporting batched multi-request handling via a generator that manages KV caches per request slot, padding for variable lengths, and efficient decode steps post-prefill.
66

7-
Key inputs: Checkpoint in `./checkpoints/ckpt-0/`, `tokenizer.model`, GPU cluster, prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
7+
Key inputs: Checkpoint in `./checkpoints/ckpt-0/` (automated download and symlink via `just download-weights` in Justfile), `tokenizer.model`, GPU cluster, prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
88
Outputs: Generated text strings.
9-
Entry points: `run.py` for test run, or `InferenceRunner().run()` generator for streaming requests.
10-
Relevant files: `run.py`, `runners.py`, `model.py`, `checkpoint.py`, `tokenizer.model`.
9+
Entry points: `just test` (runs `run.py`) for test run, or `runners.InferenceRunner().run()` generator for streaming requests. Use `nix develop` from `flake.nix` for dev environment setup including deps and tools.
10+
Relevant files (core): `run.py`, `runners.py`, `model.py`, `checkpoint.py`, `tokenizer.model`.
11+
12+
Development and setup files: `Justfile` (tasks for download and test), `flake.nix` (Nix dev shell), `.envrc` (direnv integration), `.env.public` (magnet link), `requirements.txt` (Python deps), `.github/hooks/pre-commit` (ruff pre-commit), `.github/workflows/test.yml` (CI linting).
1113

1214
The workflow orchestrates model loading, compilation of sharded compute functions, prompt processing (prefill KV cache while sampling first token), and iterative single-token generation using cached attention keys/values, until max length or EOS.
1315

@@ -45,25 +47,28 @@ The workflow orchestrates model loading, compilation of sharded compute function
4547

4648
```mermaid
4749
sequenceDiagram
50+
participant Setup as "Dev Setup (Addition)"
4851
participant User
4952
participant RunPy as run.py
5053
participant IR as InferenceRunner
5154
participant MR as ModelRunner
5255
participant Model as model.py
5356
participant Checkpoint as checkpoint.py
5457
participant JAX as JAX Runtime
55-
User->>RunPy: Execute main()
58+
Setup->>User: nix develop<br/>direnv allow<br/>just download-weights<br/>(reproducible env, deps install, checkpoint torrent download & symlink)
59+
User->>RunPy: Execute main()<br/>or just test
5660
RunPy->>IR: Create with config, MR, paths, meshes
5761
IR->>MR: initialize(dummy_data, meshes)
5862
MR->>Model: model.initialize(), fprop_dtype=bf16
5963
Note over MR,JAX: Calculate batch sizes, create mesh (data, model axes)
6064
MR->>MR: hk.transform forward/logits_fn with pjit sharding
6165
MR->>Checkpoint: load_or_init -> restore(shapes, mesh, sharding)
6266
Checkpoint->>MR: Sharded params (TrainingState)
63-
IR->>IR: Load tokenizer, compile pjit funcs (sample_step, prefill_memory, new_memory) with shardings
67+
IR->>IR: Load tokenizer<br/>compile pjit funcs (sample_step, prefill_memory, new_memory)<br/>with shardings
6468
IR->>IR: Precompile with dummy prompts for pad_sizes
6569
RunPy->>IR: gen = run() // generator setup with initial memory, settings, etc.
6670
```
71+
Note: New "Dev Setup" participant and steps reflect PR #286 additions for environment and data preparation. Core sequence unchanged.
6772

6873
## Inference and Sampling Sequence
6974

@@ -104,6 +109,41 @@ sequenceDiagram
104109
end
105110
```
106111

112+
## Development Environment and Setup Sequence
113+
114+
PR #286 adds infrastructure for reproducible dev environments and automated setup, streamlining preparation for this workflow.
115+
116+
### Setup Sequence
117+
118+
```mermaid
119+
sequenceDiagram
120+
participant Dev as Developer
121+
participant Nix as Nix Flake
122+
participant Direnv as Direnv (.envrc)
123+
participant Env as .env.public
124+
participant Just as Justfile
125+
participant Transmission as Transmission CLI
126+
participant Checkpoints as checkpoints/
127+
participant GitHooks as Git Hooks
128+
participant Ruff as Ruff Linter
129+
participant Python as Python Venv
130+
Dev->>Nix: nix develop or direnv allow
131+
Nix->>Direnv: source .envrc (use flake, python layout)
132+
Direnv->>Env: load GROK_MAGNET_LINK
133+
Nix->>Just: install just
134+
Nix->>Transmission: install transmission
135+
Nix->>Ruff: install ruff
136+
Nix->>Python: create .venv, pip install requirements.txt
137+
Nix->>GitHooks: git config core.hooksPath .github/hooks
138+
GitHooks->>Ruff: pre-commit runs ruff check
139+
Dev->>Just: just download-weights
140+
Just->>Transmission: transmission-cli --download-dir checkpoints $GROK_MAGNET_LINK
141+
Transmission->>Checkpoints: download grok-1/ckpt-0/
142+
Just->>Checkpoints: ln -s grok-1/ckpt-0 ckpt-0
143+
Dev->>Just: just test (runs python run.py)
144+
Note over Dev,Ruff: GitHub CI test.yml runs ruff on PRs/push
145+
```
146+
107147
## Sharding and Distributed Execution
108148

109149
- **Mesh Configuration**: `make_mesh(local=(data_replicas, model_par), between_hosts=(data_hosts, model_hosts))` creates hybrid mesh for SPMD parallelism. E.g., local 1x8 shards model across 8 GPUs.
@@ -127,6 +167,13 @@ sequenceDiagram
127167
- **Error/Edge Cases**: Assumes sufficient memory/GPUs; handles long contexts by left-truncation/padding. No built-in EOS handling (relies on max_len or app logic). Quantized weights require custom unpickling.
128168
- **Performance Notes**: MoE router/experts use JAX vmap/shard_map (serial per-token, inefficient for prod). Focus on correctness/single-host validation.
129169
- **Extensibility**: Modular Haiku design allows custom configs/modules. Generator interface suits serving multiple prompts concurrently.
130-
- **Dependencies & Setup**: `requirements.txt` (jax[cuda12_pip], haiku, etc.). Download ckpt via torrent/HF, place in checkpoints/.
170+
- **Dependencies & Setup**:
171+
- Python dependencies: `requirements.txt` (jax[cuda12_pip], haiku, sentencepiece, numpy, etc.).
172+
- Reproducible dev environment: `flake.nix` enables `nix develop` shell that auto-creates `.venv`, installs requirements via pip, tools (just, transmission for torrent, ruff linter), and sets git core.hooksPath to `.github/hooks` for pre-commit ruff checks.
173+
- Direnv: `.envrc` for automatic nix flake and python layout activation, loads `.env.public` containing `GROK_MAGNET_LINK`.
174+
- Checkpoint download: ~314B weights via torrent (magnet URI) or Hugging Face to `./checkpoints/`. Place/ symlink as `ckpt-0`. Automated with `just download-weights` (uses transmission-cli, creates dir/symlink).
175+
- Testing: `just test` runs `python run.py`; or direct execution.
176+
- Quality control: Ruff linting in local pre-commit and GitHub Actions CI (`.github/workflows/test.yml`) on PRs and pushes to main.
177+
- Concerns: Large download untested in PR due to size; requires stable internet and sufficient disk space.
131178

132179
This document captures the high-level design, derived from code analysis.

pr-analysis-286.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# PR #286: Workflow Design Impact Analysis
2+
3+
## Affected Workflows
4+
- **Grok-1 Inference and Sampling** (Workflow 1): This workflow is impacted by the PR's enhancements to development setup and checkpoint acquisition processes. Evidence from PR description and changed files shows additions of Nix-based reproducible environments (flake.nix, .envrc), task automation for downloading model weights via torrent (Justfile, .env.public, transmission), testing (just test running run.py entry point), and linting enforcement (pre-commit, test.yml). These align with and expand the design doc's \"Dependencies & Setup\" section mentioning torrent/HF download to checkpoints/ckpt-0/, a key input. Core runtime flows unchanged. [PR #286](https://github.com/xai-org/grok-1/pull/286)
5+
6+
Workflows 2 (Model Loading) and 3 (Forward Pass) unaffected, lacking setup references in docs and no PR changes to their core files/logic.
7+
8+
## Workflow 1 Analysis
9+
### Summary of design changes
10+
Specific aspects affected: Prerequisites for workflow execution, including environment configuration and model data preparation. The PR adds a structured dev setup layer before user invocation of run.py.
11+
12+
Implementation:
13+
- Deterministic deps via nixpkgs in flake.nix, with shellHook automating venv, pip installs, git hooks setup.
14+
- Auto-activation via direnv (.envrc loading .env.public magnet).
15+
- Tasks in Justfile: download-weights (torrent download + symlink), test (run.py).
16+
- Quality: ruff integration locally and in CI for PR validation.
17+
18+
Benefits: Reproducibility across systems, reduced manual steps for large downloads/setup, enforced standards preventing bugs. Implications: Easier collaboration/onboarding; potential Nix learning curve; untested download task (per PR concerns).
19+
20+
The design docs have been updated in .exp/design-workflow-1-grok-1-inference-and-sampling.md to reflect these changes, including new/updated diagrams and sections.
21+
22+
### Updated Diagrams Showing Changes
23+
**Initialization Sequence (with additions)**: New green-indicated steps/participant for setup phase; yellow for modified user-run.py interaction; no red removals.
24+
25+
```mermaid
26+
sequenceDiagram
27+
participant Setup as "Dev Setup (Addition)"
28+
participant User
29+
participant RunPy as run.py
30+
participant IR as InferenceRunner
31+
participant MR as ModelRunner
32+
participant Model as model.py
33+
participant Checkpoint as checkpoint.py
34+
participant JAX as JAX Runtime
35+
Setup->>User: nix develop<br/>direnv allow<br/>just download-weights<br/>(reproducible env, deps install, checkpoint torrent download & symlink)
36+
User->>RunPy: Execute main()<br/>or just test
37+
RunPy->>IR: Create with config, MR, paths, meshes
38+
IR->>MR: initialize(dummy_data, meshes)
39+
MR->>Model: model.initialize(), fprop_dtype=bf16
40+
Note over MR,JAX: Calculate batch sizes, create mesh (data, model axes)
41+
MR->>MR: hk.transform forward/logits_fn with pjit sharding
42+
MR->>Checkpoint: load_or_init -> restore(shapes, mesh, sharding)
43+
Checkpoint->>MR: Sharded params (TrainingState)
44+
IR->>IR: Load tokenizer<br/>compile pjit funcs (sample_step, prefill_memory, new_memory)<br/>with shardings
45+
IR->>IR: Precompile with dummy prompts for pad_sizes
46+
RunPy->>IR: gen = run() // generator setup with initial memory, settings, etc.
47+
```
48+
49+
**New Setup Sequence Diagram** (additions only, green by nature):
50+
51+
```mermaid
52+
sequenceDiagram
53+
participant Dev as Developer
54+
participant Nix as Nix Flake
55+
participant Direnv as Direnv (.envrc)
56+
participant Env as .env.public
57+
participant Just as Justfile
58+
participant Transmission as Transmission CLI
59+
participant Checkpoints as checkpoints/
60+
participant GitHooks as Git Hooks
61+
participant Ruff as Ruff Linter
62+
participant Python as Python Venv
63+
Dev->>Nix: nix develop or direnv allow
64+
Nix->>Direnv: source .envrc (use flake, python layout)
65+
Direnv->>Env: load GROK_MAGNET_LINK
66+
Nix->>Just: install just
67+
Nix->>Transmission: install transmission
68+
Nix->>Ruff: install ruff
69+
Nix->>Python: create .venv, pip install requirements.txt
70+
Nix->>GitHooks: git config core.hooksPath .github/hooks
71+
GitHooks->>Ruff: pre-commit runs ruff check
72+
Dev->>Just: just download-weights
73+
Just->>Transmission: transmission-cli --download-dir checkpoints $GROK_MAGNET_LINK
74+
Transmission->>Checkpoints: download grok-1/ckpt-0/
75+
Just->>Checkpoints: ln -s grok-1/ckpt-0 ckpt-0
76+
Dev->>Just: just test (runs python run.py)
77+
Note over Dev,Ruff: GitHub CI test.yml runs ruff on PRs/push
78+
```
79+
80+
The Inference and Sampling diagram remains unchanged, as PR does not affect sampling flows.

0 commit comments

Comments
 (0)